Skip to content

Commit 68d88cf

Browse files
Fulton Wangfacebook-github-bot
authored andcommitted
add ArnoldiInfluenceFunction (#1187)
Summary: Pull Request resolved: #1187 This diff implements `ArnoldiInfluenceFunction`, which was described, along with `NaiveInfluenceFunction` in D40541294. Please see that diff for detailed description. Previously implementations of both methods had been 1 diff. Now, `ArnoldiInfluenceFunction` is separated out for easier review. Reviewed By: vivekmig Differential Revision: D42006733 fbshipit-source-id: 14e82d30d56fb75dcdb5e77db9c93d626430a74f
1 parent bd1b4c6 commit 68d88cf

File tree

8 files changed

+1694
-10
lines changed

8 files changed

+1694
-10
lines changed

captum/influence/_core/arnoldi_influence_function.py

Lines changed: 1022 additions & 0 deletions
Large diffs are not rendered by default.

captum/influence/_utils/common.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,57 @@ def _influence_route_to_helpers(
859859
)
860860

861861

862+
def _parameter_dot(
863+
params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...]
864+
) -> Tensor:
865+
"""
866+
returns the dot-product of 2 tensors, represented as tuple of tensors.
867+
"""
868+
return torch.tensor(
869+
sum(
870+
torch.sum(param_1 * param_2)
871+
for (param_1, param_2) in zip(params_1, params_2)
872+
)
873+
)
874+
875+
876+
def _parameter_add(
877+
params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...]
878+
) -> Tuple[Tensor, ...]:
879+
"""
880+
returns the sum of 2 tensors, represented as tuple of tensors.
881+
"""
882+
return tuple(param_1 + param_2 for (param_1, param_2) in zip(params_1, params_2))
883+
884+
885+
def _parameter_multiply(params: Tuple[Tensor, ...], c: Tensor) -> Tuple[Tensor, ...]:
886+
"""
887+
multiplies all tensors in a tuple of tensors by a given scalar
888+
"""
889+
return tuple(param * c for param in params)
890+
891+
892+
def _parameter_to(params: Tuple[Tensor, ...], **to_kwargs) -> Tuple[Tensor, ...]:
893+
"""
894+
applies the `to` method to all tensors in a tuple of tensors
895+
"""
896+
return tuple(param.to(**to_kwargs) for param in params)
897+
898+
899+
def _parameter_linear_combination(
900+
paramss: List[Tuple[Tensor, ...]], cs: Tensor
901+
) -> Tuple[Tensor, ...]:
902+
"""
903+
scales each parameter (tensor of tuples) in a list by the corresponding scalar in a
904+
1D tensor of the same length, and sums up the scaled parameters
905+
"""
906+
assert len(cs.shape) == 1
907+
result = _parameter_multiply(paramss[0], cs[0])
908+
for (params, c) in zip(paramss[1:], cs[1:]):
909+
result = _parameter_add(result, _parameter_multiply(params, c))
910+
return result
911+
912+
862913
def _compute_jacobian_sample_wise_grads_per_batch(
863914
influence_inst: Union["TracInCP", "InfluenceFunctionBase"],
864915
inputs: Tuple[Any, ...],
@@ -1015,7 +1066,9 @@ def _functional_call(model, d, features):
10151066
def _dataset_fn(dataloader, batch_fn, reduce_fn, *batch_fn_args, **batch_fn_kwargs):
10161067
"""
10171068
Applies `batch_fn` to each batch in `dataloader`, reducing the results using
1018-
`reduce_fn`. This is useful for computing Hessians over an entire dataloader.
1069+
`reduce_fn`. This is useful for computing Hessians and Hessian-vector
1070+
products over an entire dataloader, and is used by both `NaiveInfluenceFunction`
1071+
and `ArnoldiInfluenceFunction`.
10191072
"""
10201073
_dataloader = iter(dataloader)
10211074

0 commit comments

Comments
 (0)