@@ -21,10 +21,10 @@ def test_simple_ablation_with_mask(self) -> None:
2121 net = BasicModel_MultiLayer ()
2222 inp = torch .tensor ([[20.0 , 50.0 , 30.0 ]], requires_grad = True )
2323 self ._ablation_test_assert (
24- net ,
25- net .linear0 ,
26- inp ,
27- ([280.0 , 280.0 , 120.0 ],),
24+ model = net ,
25+ layer = net .linear0 ,
26+ test_input = inp ,
27+ expected_ablation = ([280.0 , 280.0 , 120.0 ],),
2828 layer_mask = torch .tensor ([[0 , 0 , 1 ]]),
2929 perturbations_per_eval = (1 , 2 , 3 ),
3030 attribute_to_layer_input = True ,
@@ -37,20 +37,20 @@ def test_multi_input_ablation(self) -> None:
3737 inp3 = torch .tensor ([[0.0 , 100.0 , 10.0 ], [2.0 , 10.0 , 3.0 ]])
3838 baseline = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
3939 self ._ablation_test_assert (
40- net ,
41- net .model .linear1 ,
42- (inp1 , inp2 , inp3 ),
43- [[168.0 , 992.0 , 148.0 ], [84.0 , 632.0 , 120.0 ]],
40+ model = net ,
41+ layer = net .model .linear1 ,
42+ test_input = (inp1 , inp2 , inp3 ),
43+ expected_ablation = [[168.0 , 992.0 , 148.0 ], [84.0 , 632.0 , 120.0 ]],
4444 additional_input = (1 ,),
4545 baselines = baseline ,
4646 perturbations_per_eval = (1 , 2 , 3 ),
4747 attribute_to_layer_input = True ,
4848 )
4949 self ._ablation_test_assert (
50- net ,
51- net .model .linear0 ,
52- (inp1 , inp2 , inp3 ),
53- [[168.0 , 992.0 , 148.0 ], [84.0 , 632.0 , 120.0 ]],
50+ model = net ,
51+ layer = net .model .linear0 ,
52+ test_input = (inp1 , inp2 , inp3 ),
53+ expected_ablation = [[168.0 , 992.0 , 148.0 ], [84.0 , 632.0 , 120.0 ]],
5454 additional_input = (1 ,),
5555 baselines = baseline ,
5656 perturbations_per_eval = (1 , 2 , 3 ),
@@ -65,21 +65,21 @@ def test_multi_input_ablation_with_layer_mask(self) -> None:
6565 baseline = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
6666 layer_mask = torch .tensor ([[0 , 1 , 0 ], [0 , 1 , 2 ]])
6767 self ._ablation_test_assert (
68- net ,
69- net .model .linear1 ,
70- (inp1 , inp2 , inp3 ),
71- [[316.0 , 992.0 , 316.0 ], [84.0 , 632.0 , 120.0 ]],
68+ model = net ,
69+ layer = net .model .linear1 ,
70+ test_input = (inp1 , inp2 , inp3 ),
71+ expected_ablation = [[316.0 , 992.0 , 316.0 ], [84.0 , 632.0 , 120.0 ]],
7272 additional_input = (1 ,),
7373 baselines = baseline ,
7474 perturbations_per_eval = (1 , 2 , 3 ),
7575 layer_mask = layer_mask ,
7676 attribute_to_layer_input = True ,
7777 )
7878 self ._ablation_test_assert (
79- net ,
80- net .model .linear0 ,
81- (inp1 , inp2 , inp3 ),
82- [[316.0 , 992.0 , 316.0 ], [84.0 , 632.0 , 120.0 ]],
79+ model = net ,
80+ layer = net .model .linear0 ,
81+ test_input = (inp1 , inp2 , inp3 ),
82+ expected_ablation = [[316.0 , 992.0 , 316.0 ], [84.0 , 632.0 , 120.0 ]],
8383 additional_input = (1 ,),
8484 baselines = baseline ,
8585 layer_mask = layer_mask ,
@@ -91,28 +91,32 @@ def test_simple_multi_input_conv_intermediate(self) -> None:
9191 inp = torch .arange (16 , dtype = torch .float ).view (1 , 1 , 4 , 4 )
9292 inp2 = torch .ones ((1 , 1 , 4 , 4 ))
9393 self ._ablation_test_assert (
94- net ,
95- net .relu1 ,
96- (inp , inp2 ),
97- [[[[4.0 , 13.0 ], [40.0 , 49.0 ]], [[0 , 0 ], [- 15.0 , - 24.0 ]]]],
94+ model = net ,
95+ layer = net .relu1 ,
96+ test_input = (inp , inp2 ),
97+ expected_ablation = [[[[4.0 , 13.0 ], [40.0 , 49.0 ]], [[0 , 0 ], [- 15.0 , - 24.0 ]]]],
9898 perturbations_per_eval = (1 , 2 , 4 , 8 , 12 , 16 ),
9999 )
100100 self ._ablation_test_assert (
101- net ,
102- net .relu1 ,
103- (inp , inp2 ),
104- ([[[4.0 , 13.0 ], [40.0 , 49.0 ]], [[0 , 0 ], [- 15.0 , - 24.0 ]]],),
101+ model = net ,
102+ layer = net .relu1 ,
103+ test_input = (inp , inp2 ),
104+ expected_ablation = (
105+ [[[4.0 , 13.0 ], [40.0 , 49.0 ]], [[0 , 0 ], [- 15.0 , - 24.0 ]]],
106+ ),
105107 baselines = torch .tensor (
106108 [[[- 4.0 , - 13.0 ], [- 2.0 , - 2.0 ]], [[0 , 0 ], [0.0 , 0.0 ]]]
107109 ),
108110 perturbations_per_eval = (1 , 2 , 4 , 8 , 12 , 16 ),
109111 attribute_to_layer_input = True ,
110112 )
111113 self ._ablation_test_assert (
112- net ,
113- net .relu1 ,
114- (inp , inp2 ),
115- [[[[17.0 , 17.0 ], [67.0 , 67.0 ]], [[0 , 0 ], [- 39.0 , - 39.0 ]]]],
114+ model = net ,
115+ layer = net .relu1 ,
116+ test_input = (inp , inp2 ),
117+ expected_ablation = [
118+ [[[17.0 , 17.0 ], [67.0 , 67.0 ]], [[0 , 0 ], [- 39.0 , - 39.0 ]]]
119+ ],
116120 perturbations_per_eval = (1 , 2 , 4 ),
117121 layer_mask = torch .tensor ([[[[0 , 0 ], [1 , 1 ]], [[2 , 2 ], [3 , 3 ]]]]),
118122 )
@@ -121,17 +125,20 @@ def test_simple_multi_output_ablation(self) -> None:
121125 net = BasicModel_MultiLayer (multi_input_module = True )
122126 inp = torch .tensor ([[0.0 , 6.0 , 0.0 ]])
123127 self ._ablation_test_assert (
124- net , net .multi_relu , inp , ([[0.0 , 7.0 , 7.0 , 7.0 ]], [[0.0 , 7.0 , 7.0 , 7.0 ]])
128+ model = net ,
129+ layer = net .multi_relu ,
130+ test_input = inp ,
131+ expected_ablation = ([[0.0 , 7.0 , 7.0 , 7.0 ]], [[0.0 , 7.0 , 7.0 , 7.0 ]]),
125132 )
126133
127134 def test_simple_multi_output_input_ablation (self ) -> None :
128135 net = BasicModel_MultiLayer (multi_input_module = True )
129136 inp = torch .tensor ([[0.0 , 6.0 , 0.0 ]])
130137 self ._ablation_test_assert (
131- net ,
132- net .multi_relu ,
133- inp ,
134- ([[0.0 , 7.0 , 7.0 , 7.0 ]], [[0.0 , 7.0 , 7.0 , 7.0 ]]),
138+ model = net ,
139+ layer = net .multi_relu ,
140+ test_input = inp ,
141+ expected_ablation = ([[0.0 , 7.0 , 7.0 , 7.0 ]], [[0.0 , 7.0 , 7.0 , 7.0 ]]),
135142 attribute_to_layer_input = True ,
136143 )
137144
@@ -151,7 +158,7 @@ def _ablation_test_assert(
151158 for batch_size in perturbations_per_eval :
152159 ablation = LayerFeatureAblation (model , layer )
153160 attributions = ablation .attribute (
154- test_input ,
161+ inputs = test_input ,
155162 target = target ,
156163 layer_mask = layer_mask ,
157164 additional_forward_args = additional_input ,
0 commit comments