@@ -353,10 +353,12 @@ def attribute(
                 formatted_feature_mask,
                 attr_progress,
                 flattened_initial_eval,
+                initial_eval,
                 n_outputs,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
                 **kwargs,
             )
         else:
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,66 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        # Process one feature at a time, rather than processing every input tensor
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+
+            current_inputs = ()
+            current_masks = []
+            for (
+                single_perturb_input,
+                single_perturb_masks,
+            ) in self._ablation_generator(
+                formatted_inputs,
+                baselines,
+                formatted_feature_mask,
+                current_feature_idxs,
+                feature_idx_to_tensor_idx,
+                **kwargs,
+            ):
+                if len(current_inputs) == 0:
+                    current_inputs = single_perturb_input
+                else:
+                    current_inputs = tuple(
+                        torch.cat((current_inputs[j], single_perturb_input[j]), dim=0)
+                        for j in range(len(current_inputs))
+                    )
+                current_masks.append(list(single_perturb_masks))
 
-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             #   agg mode: (*initial_eval.shape)
@@ -501,8 +554,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )
 
             if attr_progress is not None:
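Note on the two swapped arguments: `current_target` and `current_additional_args` are the per-example targets and extra forward args tiled once per perturbation, so they stay aligned with the ablated copies concatenated along dim 0. A minimal sketch of the tiling idea (not Captum's actual `_expand_target` / `_expand_additional_forward_args`, which also handle non-tensor arguments; the whole-batch tiling order here is an assumption matching the `torch.cat(..., dim=0)` above):

```python
import torch

def tile_for_perturbations(arg: torch.Tensor, n_perturb: int) -> torch.Tensor:
    # one whole-batch copy per perturbation, matching torch.cat(..., dim=0) above
    return torch.cat([arg] * n_perturb, dim=0)

target = torch.tensor([2, 0])             # one label per example
print(tile_for_perturbations(target, 3))  # tensor([2, 0, 2, 0, 2, 0])
```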
@@ -515,13 +568,16 @@ def _attribute_with_cross_tensor_feature_masks(
 
             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights
 
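Taken together, the loop above trades one forward call per feature for one call per chunk of `perturbations_per_eval` features, stacking the single-feature ablations along the batch dimension. A self-contained toy illustrating the mechanics (the `model` and the zero-baseline ablation are assumptions for this sketch, not the library code path):

```python
import torch

num_examples, n_feats, perturbations_per_eval = 2, 6, 3
x = torch.randn(num_examples, n_feats)
mask = torch.arange(n_feats).expand(num_examples, n_feats)  # feature id per column

def model(inp: torch.Tensor) -> torch.Tensor:
    return inp.sum(dim=1)  # output's 1st dim grows with the batch (required)

# Ablate features 0..2 in a single forward: three copies of x, each with one
# feature zeroed out, concatenated along dim 0 into a batch of 6.
batched = torch.cat(
    [torch.where(mask == f, torch.zeros_like(x), x) for f in range(3)], dim=0
)
modified_eval = model(batched).reshape(perturbations_per_eval, num_examples)
eval_diff = model(x).unsqueeze(0) - modified_eval  # (n_perturb, num_examples)
print(eval_diff.shape)                             # torch.Size([3, 2])
```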
@@ -530,6 +586,7 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
@@ -540,11 +597,8 @@ def _ablation_generator(
         None,
         None,
     ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
         # Process one feature at a time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
+        for feature_idx in feature_idxs:
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
                     inputs,
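Passing `feature_idxs` in explicitly keeps the generator a flat per-feature loop and moves the chunking decision to the caller, where it reduces to plain slicing (a schematic restatement of the slicing shown earlier):

```python
all_feature_idxs, k = [0, 1, 2, 3, 4, 5, 6], 3  # k = perturbations_per_eval
chunks = [all_feature_idxs[i : i + k] for i in range(0, len(all_feature_idxs), k)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6]]
```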
@@ -784,7 +838,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
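The progress total now accounts for batching in the cross-tensor path too. For instance, with `feature_counts = (5, 7)` and `perturbations_per_eval = 3`, the cross-tensor path batches across all 12 features, while the per-tensor path rounds up within each tensor:

```python
import math

feature_counts, perturbations_per_eval = (5, 7), 3
cross_tensor = math.ceil(sum(feature_counts) / perturbations_per_eval)
per_tensor = sum(math.ceil(c / perturbations_per_eval) for c in feature_counts)
print(cross_tensor, per_tensor)  # 4 5
```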
@@ -1187,43 +1241,76 @@ def _process_ablated_out(
             weights[i] += current_mask.float().sum(dim=0)
 
         total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
         return total_attrib, weights
 
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Optional[Tensor], ...],
+        current_mask: List[List[Optional[Tensor]]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # the input and must not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbations; this is not the same as perturbations_per_eval
+        # when there are not enough features left to perturb
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward output for perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows in the same ratio, i.e., is not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
 
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
         modified_eval = modified_eval.reshape(-1, n_outputs)
         # eval_diff in shape (n_feature_perturbed, n_outputs)
         eval_diff = flattened_initial_eval - modified_eval
-        eval_diff_shape = eval_diff.shape
-
-        # append the shape of one input example
-        # to make it broadcastable to mask
 
-        if self.use_weights:
-            for weight, mask in zip(weights, current_mask):
-                if mask is not None:
-                    weight += mask.float().sum(dim=0)
-        for i, mask in enumerate(current_mask):
-            if mask is None or inputs[i].numel() == 0:
-                continue
-            eval_diff = eval_diff.reshape(
-                eval_diff_shape + (inputs[i].dim() - 1) * (1,)
-            )
-            eval_diff = eval_diff.to(total_attrib[i].device)
-            total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+        for j in range(int(n_perturb)):
+            single_perturb_mask = current_mask[j]
+            if self.use_weights:
+                for weight, mask in zip(weights, single_perturb_mask):
+                    if mask is not None:
+                        weight += mask.float()
+            for i, mask in enumerate(single_perturb_mask):
+                this_input = inputs[i][j * num_examples : (j + 1) * num_examples]
+                if mask is None or this_input.numel() == 0:
+                    continue
+                eval_diff_j = eval_diff[j].reshape(
+                    eval_diff[j].shape + (this_input.dim() - 1) * (1,)
+                )
+                eval_diff_j = eval_diff_j.to(total_attrib[i].device)
+                total_attrib[i] += eval_diff_j * mask.to(attrib_type)
         return total_attrib, weights
 
     def _fut_tuple_to_accumulate_fut_list(
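The reworked accumulation validates once (then caches the result in `_is_output_shape_valid`) that the forward output's leading dim scales with the batch, then slices the concatenated batch back into per-perturbation chunks. A toy version of the inner accumulation step, assuming a single input tensor and one dense mask per perturbation (shapes chosen so `n_outputs == num_examples`):

```python
import torch

num_examples, n_outputs, n_perturb = 2, 2, 3
inputs = torch.randn(n_perturb * num_examples, 4)  # concatenated ablated batch
eval_diff = torch.randn(n_perturb, n_outputs)      # one row per perturbation
masks = [torch.randint(0, 2, (num_examples, 4)) for _ in range(n_perturb)]
total_attrib = torch.zeros(num_examples, 4)

for j in range(n_perturb):
    this_input = inputs[j * num_examples : (j + 1) * num_examples]
    # append singleton dims so eval_diff[j] broadcasts against the mask
    eval_diff_j = eval_diff[j].reshape(
        eval_diff[j].shape + (this_input.dim() - 1) * (1,)
    )
    total_attrib += eval_diff_j * masks[j].to(total_attrib.dtype)
print(total_attrib.shape)  # torch.Size([2, 4])
```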