@@ -851,6 +851,57 @@ def _influence_route_to_helpers(
851
851
)
852
852
853
853
854
+ def _parameter_dot (
855
+ params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
856
+ ) -> Tensor :
857
+ """
858
+ returns the dot-product of 2 tensors, represented as tuple of tensors.
859
+ """
860
+ return torch .Tensor (
861
+ sum (
862
+ torch .sum (param_1 * param_2 )
863
+ for (param_1 , param_2 ) in zip (params_1 , params_2 )
864
+ )
865
+ )
866
+
867
+
868
+ def _parameter_add (
869
+ params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
870
+ ) -> Tuple [Tensor , ...]:
871
+ """
872
+ returns the sum of 2 tensors, represented as tuple of tensors.
873
+ """
874
+ return tuple (param_1 + param_2 for (param_1 , param_2 ) in zip (params_1 , params_2 ))
875
+
876
+
877
+ def _parameter_multiply (params : Tuple [Tensor , ...], c : Tensor ) -> Tuple [Tensor , ...]:
878
+ """
879
+ multiplies all tensors in a tuple of tensors by a given scalar
880
+ """
881
+ return tuple (param * c for param in params )
882
+
883
+
884
+ def _parameter_to (params : Tuple [Tensor , ...], ** to_kwargs ) -> Tuple [Tensor , ...]:
885
+ """
886
+ applies the `to` method to all tensors in a tuple of tensors
887
+ """
888
+ return tuple (param .to (** to_kwargs ) for param in params )
889
+
890
+
891
+ def _parameter_linear_combination (
892
+ paramss : List [Tuple [Tensor , ...]], cs : Tensor
893
+ ) -> Tuple [Tensor , ...]:
894
+ """
895
+ scales each parameter (tensor of tuples) in a list by the corresponding scalar in a
896
+ 1D tensor of the same length, and sums up the scaled parameters
897
+ """
898
+ assert len (cs .shape ) == 1
899
+ result = _parameter_multiply (paramss [0 ], cs [0 ])
900
+ for (params , c ) in zip (paramss [1 :], cs [1 :]):
901
+ result = _parameter_add (result , _parameter_multiply (params , c ))
902
+ return result
903
+
904
+
854
905
def _compute_jacobian_sample_wise_grads_per_batch (
855
906
influence_inst : Union ["TracInCP" , "InfluenceFunctionBase" ],
856
907
inputs : Tuple [Any , ...],
@@ -1007,7 +1058,9 @@ def _functional_call(model, d, features):
1007
1058
def _dataset_fn (dataloader , batch_fn , reduce_fn , * batch_fn_args , ** batch_fn_kwargs ):
1008
1059
"""
1009
1060
Applies `batch_fn` to each batch in `dataloader`, reducing the results using
1010
- `reduce_fn`. This is useful for computing Hessians over an entire dataloader.
1061
+ `reduce_fn`. This is useful for computing Hessians and Hessian-vector
1062
+ products over an entire dataloader, and is used by both `NaiveInfluenceFunction`
1063
+ and `ArnoldiInfluenceFunction`.
1011
1064
"""
1012
1065
_dataloader = iter (dataloader )
1013
1066
0 commit comments