|
112 | 112 | "model": BasicModel_MultiLayer(),
|
113 | 113 | "attribute_args": {"inputs": torch.randn(4, 3), "target": 1},
|
114 | 114 | },
|
| 115 | + { |
| 116 | + "name": "basic_single_target_cross_tensor_attributions", |
| 117 | + "algorithms": [ |
| 118 | + FeatureAblation, |
| 119 | + FeaturePermutation, |
| 120 | + ], |
| 121 | + "model": BasicModel_MultiLayer(), |
| 122 | + "attribute_args": { |
| 123 | + "inputs": torch.randn(4, 3), |
| 124 | + "target": 1, |
| 125 | + "enable_cross_tensor_attribution": True, |
| 126 | + }, |
| 127 | + }, |
115 | 128 | {
|
116 | 129 | "name": "basic_multi_input",
|
117 | 130 | "algorithms": [
|
|
179 | 192 | },
|
180 | 193 | "dp_delta": 0.0005,
|
181 | 194 | },
|
| 195 | + { |
| 196 | + "name": "basic_multi_input_multi_target_cross_tensor_attributions", |
| 197 | + "algorithms": [ |
| 198 | + FeatureAblation, |
| 199 | + FeaturePermutation, |
| 200 | + ], |
| 201 | + "model": BasicModel_MultiLayer_MultiInput(), |
| 202 | + "attribute_args": { |
| 203 | + "inputs": (10 * torch.randn(6, 3), 5 * torch.randn(6, 3)), |
| 204 | + "additional_forward_args": (2 * torch.randn(6, 3), 5), |
| 205 | + "target": [0, 1, 1, 0, 0, 1], |
| 206 | + "enable_cross_tensor_attribution": True, |
| 207 | + }, |
| 208 | + "dp_delta": 0.0005, |
| 209 | + }, |
182 | 210 | {
|
183 | 211 | "name": "basic_multiple_tuple_target",
|
184 | 212 | "algorithms": [
|
|
202 | 230 | "additional_forward_args": (None, True),
|
203 | 231 | },
|
204 | 232 | },
|
| 233 | + { |
| 234 | + "name": "basic_multiple_tuple_target_cross_tensor_attributions", |
| 235 | + "algorithms": [ |
| 236 | + FeatureAblation, |
| 237 | + FeaturePermutation, |
| 238 | + ], |
| 239 | + "model": BasicModel_MultiLayer(), |
| 240 | + "attribute_args": { |
| 241 | + "inputs": torch.randn(4, 3), |
| 242 | + "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], |
| 243 | + "additional_forward_args": (None, True), |
| 244 | + "enable_cross_tensor_attribution": True, |
| 245 | + }, |
| 246 | + }, |
205 | 247 | {
|
206 | 248 | "name": "basic_tensor_single_target",
|
207 | 249 | "algorithms": [
|
|
243 | 285 | "target": torch.tensor([1, 1, 0, 0]),
|
244 | 286 | },
|
245 | 287 | },
|
| 288 | + { |
| 289 | + "name": "basic_tensor_multi_target_cross_tensor_attributions", |
| 290 | + "algorithms": [ |
| 291 | + FeatureAblation, |
| 292 | + FeaturePermutation, |
| 293 | + ], |
| 294 | + "model": BasicModel_MultiLayer(), |
| 295 | + "attribute_args": { |
| 296 | + "inputs": torch.randn(4, 3), |
| 297 | + "target": torch.tensor([1, 1, 0, 0]), |
| 298 | + "enable_cross_tensor_attribution": True, |
| 299 | + }, |
| 300 | + }, |
246 | 301 | # Primary Configs with Baselines
|
247 | 302 | {
|
248 | 303 | "name": "basic_multiple_tuple_target_with_baselines",
|
|
262 | 317 | "additional_forward_args": (None, True),
|
263 | 318 | },
|
264 | 319 | },
|
| 320 | + { |
| 321 | + "name": "basic_multiple_tuple_target_with_baselines_cross_tensor_attributions", |
| 322 | + "algorithms": [ |
| 323 | + FeatureAblation, |
| 324 | + ], |
| 325 | + "model": BasicModel_MultiLayer(), |
| 326 | + "attribute_args": { |
| 327 | + "inputs": torch.randn(4, 3), |
| 328 | + "baselines": 0.5 * torch.randn(4, 3), |
| 329 | + "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], |
| 330 | + "additional_forward_args": (None, True), |
| 331 | + "enable_cross_tensor_attribution": True, |
| 332 | + }, |
| 333 | + }, |
265 | 334 | {
|
266 | 335 | "name": "basic_tensor_single_target_with_baselines",
|
267 | 336 | "algorithms": [
|
|
279 | 348 | "target": torch.tensor([0]),
|
280 | 349 | },
|
281 | 350 | },
|
| 351 | + { |
| 352 | + "name": "basic_tensor_single_target_with_baselines_cross_tensor_attributions", |
| 353 | + "algorithms": [ |
| 354 | + FeatureAblation, |
| 355 | + ], |
| 356 | + "model": BasicModel_MultiLayer(), |
| 357 | + "attribute_args": { |
| 358 | + "inputs": torch.randn(4, 3), |
| 359 | + "baselines": 0.5 * torch.randn(4, 3), |
| 360 | + "target": torch.tensor([0]), |
| 361 | + "enable_cross_tensor_attribution": True, |
| 362 | + }, |
| 363 | + }, |
282 | 364 | # Primary Configs with Internal Batching
|
283 | 365 | {
|
284 | 366 | "name": "basic_multiple_tuple_target_with_internal_batching",
|
|
0 commit comments