 from captum._utils.progress import NullProgress, progress
 from captum.influence._core.influence import DataInfluence
 from captum.influence._utils.common import (
+    _check_loss_fn,
     _format_inputs_dataset,
     _get_k_most_influential_helper,
     _gradient_dot_product,
@@ -102,6 +103,7 @@ def __init__(
         checkpoints_load_func: Callable = _load_flexible_state_dict,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
     ) -> None:
         r"""
         Args:
@@ -152,6 +154,19 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                If not provided, the loss function for test examples is assumed to
+                be the same as the loss function for training examples, i.e.
+                `loss_fn`.
+                Default: None
155170 """
156171
157172 self .model = model
@@ -167,6 +182,8 @@ def __init__(
167182
168183 self .checkpoints_load_func = checkpoints_load_func
169184 self .loss_fn = loss_fn
185+ # If test_loss_fn not provided, it's assumed to be same as loss_fn
186+ self .test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
170187 self .batch_size = batch_size
171188
172189 if not isinstance (train_dataset , DataLoader ):
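To make the new argument concrete, here is a minimal sketch (not part of the diff) of passing a per-example `test_loss_fn` that scores a fixed class's logit. The model, dataset, and checkpoint path are placeholders, and the sketch assumes loss functions are called as `loss_fn(model_output, targets)`, like standard torch.nn losses:

    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset
    from captum.influence import TracInCP

    # Toy, purely illustrative setup.
    model = nn.Linear(10, 5)
    train_dataset = TensorDataset(torch.randn(20, 10), torch.randint(0, 5, (20,)))

    FIXED_CLASS = 3  # hypothetical class whose prediction we want to explain

    def fixed_class_logit(outputs, targets):
        # Per-example "loss": the logit of the fixed class. Suitable as
        # `test_loss_fn` when `sample_wise_grads_per_batch` is False, where an
        # unreduced, per-example loss is expected.
        return outputs[:, FIXED_CLASS]

    tracin = TracInCP(
        model,
        train_dataset,
        checkpoints=["model_epoch_0.pt"],  # hypothetical checkpoint path
        loss_fn=nn.CrossEntropyLoss(reduction="none"),
        test_loss_fn=fixed_class_logit,
    )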
@@ -489,6 +506,7 @@ def __init__(
         layers: Optional[List[str]] = None,
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
+        test_loss_fn: Optional[Union[Module, Callable]] = None,
         sample_wise_grads_per_batch: bool = False,
     ) -> None:
         r"""
@@ -561,6 +579,24 @@ def __init__(
                 `train_dataset` is a Dataset. If `train_dataset`
                 is a DataLoader, then `batch_size` is ignored as an argument.
                 Default: 1
+            test_loss_fn (Callable, optional): In some cases, one may want to use a
+                separate loss function for training examples, i.e. those in
+                `train_dataset`, and for test examples, i.e. those
+                represented by the `inputs` and `targets` arguments to the
+                `influence` method. For example, if one wants to calculate the
+                influence score of a training example on a test example's
+                prediction for a fixed class, `test_loss_fn` could map from the
+                logits for all classes to the logits for a fixed class.
+                `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+                Thus, the same checks that we apply to `loss_fn` are also applied
+                to `test_loss_fn`, if the latter is provided. Note that the
+                constraints on `loss_fn` and `test_loss_fn` both depend on
+                `sample_wise_grads_per_batch`. This means `loss_fn` and
+                `test_loss_fn` must either both be "per-example" loss functions,
+                or both be "reduction" loss functions. If not provided, the loss
+                function for test examples is assumed to be the same as the loss
+                function for training examples, i.e. `loss_fn`.
+                Default: None
             sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
                 computations w.r.t. model parameters aggregates the results for a
                 batch and does not allow to access sample-wise gradients w.r.t.
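As a quick reference for the constraint described above, a hedged sketch of the two valid pairings of `loss_fn` and `test_loss_fn` (the loss objects are illustrative; any losses with the stated reductions would do):

    import torch.nn as nn

    # sample_wise_grads_per_batch=False: both losses must be "per-example",
    # i.e. return one loss per example (reduction="none").
    per_example_kwargs = dict(
        loss_fn=nn.CrossEntropyLoss(reduction="none"),
        test_loss_fn=nn.CrossEntropyLoss(reduction="none"),
        sample_wise_grads_per_batch=False,
    )

    # sample_wise_grads_per_batch=True: both losses must be "reduction" losses,
    # i.e. reduce per-example losses with "sum" or "mean".
    reduction_kwargs = dict(
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        test_loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        sample_wise_grads_per_batch=True,
    )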
@@ -590,51 +626,23 @@ def __init__(
             checkpoints_load_func,
             loss_fn,
             batch_size,
+            test_loss_fn,
         )
 
         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch
 
-        # If we are able to access the reduction used by `loss_fn`, we check whether
-        # the reduction is compatible with `sample_wise_grads_per_batch`
-        if isinstance(loss_fn, Module) and hasattr(
-            loss_fn, "reduction"
-        ):  # TODO: allow loss_fn to be Callable
-            if self.sample_wise_grads_per_batch:
-                assert loss_fn.reduction in ["sum", "mean"], (
-                    'reduction for `loss_fn` must be "sum" or "mean" when '
-                    "`sample_wise_grads_per_batch` is True"
-                )
-                self.reduction_type = str(loss_fn.reduction)
-            else:
-                assert loss_fn.reduction == "none", (
-                    'reduction for `loss_fn` must be "none" when '
-                    "`sample_wise_grads_per_batch` is False"
-                )
-        else:
-            # if we are unable to access the reduction used by `loss_fn`, we warn
-            # the user about the assumptions we are making regarding the reduction
-            # used by `loss_fn`
-            if self.sample_wise_grads_per_batch:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is True, the implementation assumes "
-                    'that `loss_fn` is a "reduction" loss function that reduces the '
-                    "per-example losses by taking their *sum*. If `loss_fn` "
-                    "instead reduces the per-example losses by taking their mean, "
-                    'please set the reduction attribute of `loss_fn` to "mean", i.e. '
-                    '`loss_fn.reduction = "mean"`. Note that if '
-                    "`sample_wise_grads_per_batch` is True, the implementation "
-                    "assumes the reduction is either a sum or mean reduction."
-                )
-                self.reduction_type = "sum"
-            else:
-                warnings.warn(
-                    'Since `loss_fn` has no "reduction" attribute, and '
-                    "`sample_wise_grads_per_batch` is False, the implementation "
-                    'assumes that `loss_fn` is a "per-example" loss function (see '
-                    "documentation for `loss_fn` for details). Please ensure that "
-                    "this is the case."
-                )
+        # check `loss_fn`
+        self.reduction_type = _check_loss_fn(
+            self, loss_fn, "loss_fn", sample_wise_grads_per_batch
+        )
+        # check `test_loss_fn` if it was provided
+        self.test_reduction_type = (
+            self.reduction_type
+            if test_loss_fn is None
+            else _check_loss_fn(
+                self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
+            )
+        )
 
         r"""
         TODO: Either restore model state after done (would have to place functionality
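The factored-out helper `_check_loss_fn` is imported from captum.influence._utils.common and is not shown in this diff. The following is only an approximation of what it plausibly does, reconstructed from the inline checks removed above; the real implementation may differ in details:

    import warnings
    from torch.nn import Module

    def _check_loss_fn_sketch(influence_instance, loss_fn, loss_fn_name,
                              sample_wise_grads_per_batch):
        # Returns the reduction type to assume for `loss_fn`; "none" is the
        # assumed default for the per-example case.
        reduction_type = "none"
        if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
            if sample_wise_grads_per_batch:
                assert loss_fn.reduction in ["sum", "mean"], (
                    f'reduction for `{loss_fn_name}` must be "sum" or "mean" '
                    "when `sample_wise_grads_per_batch` is True"
                )
                reduction_type = str(loss_fn.reduction)
            else:
                assert loss_fn.reduction == "none", (
                    f'reduction for `{loss_fn_name}` must be "none" when '
                    "`sample_wise_grads_per_batch` is False"
                )
        else:
            # cannot inspect the reduction; warn about the assumption made
            if sample_wise_grads_per_batch:
                warnings.warn(
                    f'`{loss_fn_name}` has no "reduction" attribute; assuming '
                    "it reduces per-example losses by taking their sum."
                )
                reduction_type = "sum"
            else:
                warnings.warn(
                    f'`{loss_fn_name}` has no "reduction" attribute; assuming '
                    "it is a per-example loss function."
                )
        return reduction_type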
@@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint):
             input_jacobians = self._basic_computation_tracincp(
                 inputs,
                 targets,
+                self.test_loss_fn,
+                self.test_reduction_type,
             )
             return (
                 _gradient_dot_product(
                     input_jacobians,
-                    self._basic_computation_tracincp(batch[0:-1], batch[-1]),
+                    self._basic_computation_tracincp(
+                        batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
+                    ),
                 )
                 * learning_rate
             )
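For intuition, the hunk above computes, per checkpoint, the dot product between the test-example gradients (taken with `test_loss_fn`) and the training-batch gradients (taken with `loss_fn`), scaled by the checkpoint's learning rate. A hedged sketch of that inner product, assuming each jacobian is a tuple of per-example gradient tensors of shape (num_examples, *param_shape):

    import torch

    def checkpoint_contribution_sketch(test_jacobians, train_jacobians, learning_rate):
        # Sum, over parameters, of the per-example gradient dot products; the
        # result has shape (num_test_examples, num_train_examples).
        total = 0
        for test_grad, train_grad in zip(test_jacobians, train_jacobians):
            total = total + torch.matmul(
                test_grad.flatten(start_dim=1), train_grad.flatten(start_dim=1).T
            )
        return total * learning_rate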
@@ -1042,7 +1054,10 @@ def get_checkpoint_contribution(checkpoint):
         for batch in _inputs_dataset:
 
             layer_jacobians = self._basic_computation_tracincp(
-                batch[0:-1], batch[-1]
+                batch[0:-1],
+                batch[-1],
+                self.loss_fn,
+                self.reduction_type,
             )
 
             # Note that all variables in this function are for an entire batch.
@@ -1179,11 +1194,14 @@ def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
         targets: Optional[Tensor] = None,
+        loss_fn: Optional[Union[Module, Callable]] = None,
+        reduction_type: Optional[str] = None,
     ) -> Tuple[Tensor, ...]:
         """
         For instances of TracInCP, computation of influence scores or self influence
         scores repeatedly calls this function for different checkpoints
-        and batches.
+        and batches. In particular, this function computes the jacobian of a loss
+        function w.r.t. parameters in the `layers` initialization argument.
 
         Args:
 
@@ -1193,20 +1211,26 @@ def _basic_computation_tracincp(
                 that `model(*inputs)` produces the predictions for the batch.
             targets (tensor or None): If computing influence scores on a loss function,
                 these are the labels corresponding to the batch `inputs`.
+                Default: None
+            loss_fn (Callable, optional): The loss function to use when computing the
+                jacobian.
+            reduction_type (str, optional): The reduction type of `loss_fn`. This
+                argument is only used if `sample_wise_grads_per_batch` was `True` in
+                initialization.
         """
         if self.sample_wise_grads_per_batch:
             return _compute_jacobian_wrt_params_with_sample_wise_trick(
                 self.model,
                 inputs,
                 targets,
-                self.loss_fn,
-                self.reduction_type,
+                loss_fn,
+                reduction_type,
                 self.layer_modules,
             )
         return _compute_jacobian_wrt_params(
             self.model,
             inputs,
             targets,
-            self.loss_fn,
+            loss_fn,
             self.layer_modules,
         )
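For readers unfamiliar with these helpers, a hedged sketch of a naive per-example jacobian computation, illustrating the role of `loss_fn` when `sample_wise_grads_per_batch` is False (the real `_compute_jacobian_wrt_params` helpers are more efficient and also handle the restriction to `self.layer_modules`):

    import torch

    def naive_per_example_jacobians(model, inputs, targets, loss_fn):
        # `loss_fn` is assumed to return a vector of per-example losses
        # (reduction="none"), as required when `sample_wise_grads_per_batch`
        # is False.
        params = [p for p in model.parameters() if p.requires_grad]
        losses = loss_fn(model(*inputs), targets)
        per_example_grads = []
        for loss in losses:
            grads = torch.autograd.grad(loss, params, retain_graph=True)
            per_example_grads.append(grads)
        # Regroup into one tensor per parameter, stacked over examples, i.e. a
        # tuple of tensors of shape (num_examples, *param_shape).
        return tuple(torch.stack(g) for g in zip(*per_example_grads))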