diff --git a/captum/influence/__init__.py b/captum/influence/__init__.py index ac2c40a618..506851fe1b 100644 --- a/captum/influence/__init__.py +++ b/captum/influence/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from captum.influence._core.influence import DataInfluence # noqa +from captum.influence._core.influence_function import NaiveInfluenceFunction # noqa from captum.influence._core.similarity_influence import SimilarityInfluence # noqa from captum.influence._core.tracincp import TracInCP, TracInCPBase # noqa from captum.influence._core.tracincp_fast_rand_proj import ( @@ -15,4 +16,5 @@ "TracInCP", "TracInCPFast", "TracInCPFastRandProj", + "NaiveInfluenceFunction", ] diff --git a/captum/influence/_core/arnoldi_influence_function.py b/captum/influence/_core/arnoldi_influence_function.py new file mode 100644 index 0000000000..25cdce1316 --- /dev/null +++ b/captum/influence/_core/arnoldi_influence_function.py @@ -0,0 +1,1022 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +import functools +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch + +from captum._utils.gradient import _extract_parameters_from_layers + +from captum.influence._utils.common import ( + _compute_batch_loss_influence_function_base, + _compute_jacobian_sample_wise_grads_per_batch, + _dataset_fn, + _format_inputs_dataset, + _functional_call, + _get_k_most_influential_helper, + _influence_batch_intermediate_quantities_influence_function, + _influence_helper_intermediate_quantities_influence_function, + _influence_route_to_helpers, + _load_flexible_state_dict, + _parameter_add, + _parameter_dot, + _parameter_linear_combination, + _parameter_multiply, + _parameter_to, + _params_to_names, + _progress_bar_constructor, + _self_influence_helper_intermediate_quantities_influence_function, + _top_eigen, + KMostInfluentialResults, +) +from captum.log import log_usage + +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from .influence_function import ( + _get_dataset_embeddings_intermediate_quantities_influence_function, + InfluenceFunctionBase, + IntermediateQuantitiesInfluenceFunction, +) + + +def _parameter_arnoldi( + hvp: Callable, + b: Tuple[Tensor, ...], + n: int, + tol: float, + projection_device: torch.device, + show_progress: bool, +) -> Tuple[List[Tuple[Tensor, ...]], Tensor]: + r""" + Given `hvp`, a function which computes the Hessian-vector product of an arbitrary + vector `v` with an implicitly-defined Hessian matrix `A`, performs the Arnoldi + iteration for `A` for `n` iterations. (We use `A`, not `H` to refer to the + Hessian, unlike elsewhere, because `H` is already used in the below explanation + of the Arnoldi iteration.) + + For more details on the Arnoldi iteration, please see Trefethen and Bau, Chp 33. + Running Arnoldi iteration for n iterations gives a basis for the Krylov subspace + spanned by :math`\{b, Ab,..., A^{n-1}b\}`, as well as a `n+1` by `n` matrix + :math`H_n` which is upper Hessenberg (all entries below the diagonal, except those + adjoining it, are 0), whose first n rows represent the restriction of `A` to the + Krylov subspace, using the basis. Here, `b` is an arbitrary initialization basis + vector. The basis is assembled into a `D` by `n+1` matrix, where the last + column is a "correction factor", i.e. not part of the basis, denoted + :math`Q_{n+1}`. 
Letting :math`Q_n` denote the matrix with the first n columns of + :math`Q_{n+1}`, the following equality is satisfied: :math`A=Q_{n+1} H_n Q_n'`. + + In this implementation, `v` is not actually a vector, but instead a tuple of + tensors, because `hvp` being a Hessian-vector product, `v` lies in parameter-space, + which Pytorch represents as tuples of tensors. This implementation avoids + flattening `v` to a 1D tensor, which leads to scalability gains. + + Args: + hvp (Callable): A callable that accepts an arbitrary tuple of tensors + `v`, which represents a parameter, and returns + `Av`, i.e. the multiplication of `v` with an implicitly defined matrix + `A` of compatible dimension, which in practice is a Hessian-vector + product. + b (tensor): The Arnoldi iteration requires an initialization basis to + construct the basis, typically randomly chosen. This is that basis, + and is a tuple of tensors. We assume that the device of `b` is the same + as the required device of input `v` to `hvp`. For example, if `hvp` + computes HVP using a model that is on the GPU, then `b` should also be + on the GPU. + n (int): The number of iterations to run the iteration for. + tol (float, optional): After many iterations, the already-obtained + basis vectors may already approximately span the Krylov subspace, + in which case the addition of additional basis vectors involves + normalizing a vector with a small norm. These vectors are not + necessary to include in the basis and furthermore, their small norm + leads to numerical issues. Therefore we stop the Arnoldi iteration + when the addition of additional vectors involves normalizing a + vector with norm below a certain threshold. This argument specifies + that threshold. + Default: 1e-4 + projection_device (torch.device) The returned quantities (which will be used + to define a projection of parameter-gradients, hence the name) are + potentially memory intensive, because they represent a basis of a + subspace in the space of parameters, which are potentially + high-dimensional. Therefore we need to be careful of out-of-memory + GPU errors. This argument represents the device where the returned + quantities should be stored, and its choice requires balancing + speed with GPU memory. + show_progress (bool): If true, the progress of the iteration (i.e. number of + basis vectors already determined) will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. + + Returns: + qs (list of tuple of tensors): A list of tuple of tensors, whose first `n` + elements contain a basis for the Krylov subspace. + H (tensor): A tensor with shape `(n+1, n)` whose first `n` rows represent + the restriction of `A` to the Krylov subspace. 
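# A minimal sketch of how `_parameter_arnoldi` might be exercised, using a toy,
# explicitly-known "Hessian" so the output can be sanity-checked by hand. The
# parameter space here is a single 3-element tensor wrapped in a tuple, and the
# concrete matrix, dimensions, and tolerances below are assumptions chosen only
# for illustration.
import torch

A = torch.diag(torch.tensor([3.0, 2.0, 1.0]))  # the implicitly-defined matrix `A`

def toy_hvp(v):
    # `v` is a tuple of tensors (here, a 1-tuple); return `Av` in the same format
    return (A @ v[0],)

b = (torch.randn(3),)  # random initialization basis, also a tuple of tensors

qs, H = _parameter_arnoldi(
    toy_hvp,
    b,
    n=3,
    tol=1e-4,
    projection_device=torch.device("cpu"),
    show_progress=False,
)
# `qs` is a list of (nearly) orthonormal basis elements, each a tuple of tensors,
# and the first rows of `H` give the upper Hessenberg restriction of `A` to the
# Krylov subspace they span.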
+ """ + # because the HVP is the computational bottleneck, we always do HVP on + # the same device as the model, which is assumed to be the device `b` is on + computation_device = next(iter(b)).device + + # all entries of `b` have the same dtype, and so can be used to determine dtype + # of `H` + H = torch.zeros(n + 1, n, dtype=next(iter(b)).dtype).to(device=projection_device) + qs = [ + _parameter_to( + _parameter_multiply(b, 1.0 / _parameter_dot(b, b) ** 0.5), + device=projection_device, + ) + ] + + iterates = range(1, n + 1) + if show_progress: + iterates = tqdm(iterates, desc="Running Arnoldi Iteration for step") + + for k in iterates: + v = _parameter_to( + hvp(_parameter_to(qs[k - 1], device=computation_device)), + device=projection_device, + ) + + for i in range(k): + H[i, k - 1] = _parameter_dot(qs[i], v) + v = _parameter_add(v, _parameter_multiply(qs[i], -H[i, k - 1])) + H[k, k - 1] = _parameter_dot(v, v) ** 0.5 + + if H[k, k - 1] < tol: + break + qs.append(_parameter_multiply(v, 1.0 / H[k, k - 1])) + + return qs[:k], H[:k, : k - 1] + + +def _parameter_distill( + qs: List[Tuple[Tensor, ...]], + H: Tensor, + k: Optional[int], + hessian_reg: float, + hessian_inverse_tol: float, +): + """ + This takes the output of `_parameter_arnoldi`, and extracts the top-k eigenvalues + / eigenvectors of the matrix that `_parameter_arnoldi` found the Krylov subspace + for. In this documentation, we will refer to that matrix by `A`. + + Args: + qs (list of tuple of tensors): A list of tuple of tensors, whose first `N` + elements contain a basis for the Krylov subspace. + H (tensor): A tensor with shape `(N+1, N)` whose first `N` rows represent + the restriction of `A` to the Krylov subspace. + k (int): The number of top eigenvalues / eigenvectors to return. Note that the + actual number returned may be less, due to filtering based on + `hessian_inverse_tol`. + hessian_reg (float): hessian_reg (float): We add an entry to the diagonal of + `H` to encourage it to be positive definite. This is that entry. + hessian_inverse_tol (float): To compute the "square root" of `H` using the top + eigenvectors / eigenvalues, the eigenvalues should be positive, and + furthermore if above a tolerance, the inversion will be more + numerically stable. Therefore, we only return eigenvectors / + eigenvalues where the eigenvalue is above a tolerance. This argument + specifies that tolerance. We do not compute the square root in this + function, but assume the output of this function will be used for + computing it, hence the need for this argument. + + Returns: + (eigenvalues, eigenvectors) (tensor, list of tuple of tensors): `eigenvalues` + is a 1D tensor of the top eigenvalues of `A`. Note that due to + filtering based on `hessian_inverse_tol`, the actual number of + eigenvalues may be less than `k`. The eigenvalues are in ascending + order, mimicking the convention of `torch.linalg.eigh`. `eigenvectors` + are the corresponding eigenvectors. Since `A` represents the Hessian + of parameters, with the parameters represented as a tuple of tensors, + the eigenvectors, because they represent parameters, are also + tuples of tensors. Therefore, `eigenvectors` is a list of tuple of + tensors. + """ + # get rid of last basis of qs, last column of H, since they are not part of + # the decomposition + qs = qs[:-1] + H = H[:-1] + + # if arnoldi basis is empty, raise exception + if len(qs) == 0: + raise Exception( + "Arnoldi basis is empty. 
Consider increasing the `arnoldi_tol` argument" + ) + + # ls, vs are the top eigenvalues / eigenvectors. however, the eigenvectors are + # expressed as coordinates using the Krylov subspace basis, qs (each column of vs + # represents a different eigenvector). + ls, vs = _top_eigen(H, k, hessian_reg, hessian_inverse_tol) + + # if no positive eigenvalues exist, we cannot compute a low-rank + # approximation of the square root of the hessian H, so raise exception + if vs.shape[1] == 0: + raise Exception( + "Restriction of Hessian to Krylov subspace has no positive " + "eigenvalues, so cannot take its square root." + ) + + # we want to express the top eigenvectors as coordinates using the standard basis. + # each column of vs represents a different eigenvector, expressed as coordinates + # using the Krylov subspace basis. to express the eigenvector using the standard + # basis, we use it as the coefficients in a linear combination with the Krylov + # subspace basis, qs. + vs_standard = [_parameter_linear_combination(qs, v) for v in vs.T] + + return ls, vs_standard + + +class ArnoldiInfluenceFunction(IntermediateQuantitiesInfluenceFunction): + r""" + This is a computationally-efficient implementation that computes the type of + "infinitesimal" influence scored defined in the paper "Understanding Black-box + Predictions via Influence Functions" by Koh et al + (https://arxiv.org/pdf/1703.04730.pdf). This implementation does *not* follow + the approach in that paper, however. Instead, it follows an implementation that is + several orders of magnitudes faster, described in the paper "Scaling Up Influence + Functions" by Schioppa et al (https://arxiv.org/pdf/2112.03052.pdf). + + This implementation computes a low-rank approximation of the inverse Hessian, i.e. + a tall and skinny (with width k) matrix :math`R` such that + :math`H^{-1} \approx RR'`, where k is small. In particular, let :math`V` be the + matrix of width k whose columns contain the top-k eigenvectors of :math`H`, and let + :math`S` be the k by k matrix whose diagonals contain the corresponding eigenvalues. + This implementation lets :math`R=VS^{-0.5}`. Thus, the core computational step is + computing the top-k eigenvalues / eigenvectors. + + This approximation is useful for several reasons: + - It avoids numerical issues associated with inverting small eigenvalues + - Since the influence score is given by + :math`\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, which is approximated by + :math`(\nabla_\theta L(x)' R) (\nabla_\theta L(z)' R)`, we can compute an + "influence embedding" for a given example :math`x`, :math`\nabla_\theta L(x)' R`, + such that the influence score of one example on another is approximately the + dot-product of their respective embeddings. + - Even for large models, we can store `R` in memory, provided k is small. This + means influence embeddings (and thus influence scores) can be efficiently + computed by doing a backwards pass to compute :math`\nabla_\theta L(x)` and then + multiplying by :math`R'`. This is orders of magnitude faster than the previous + LISSA approach of Koh et al, which to compute the influence score involving a + given example, need to compute Hessian-vector products involving on the order + of 10^4 examples. + + The key novelty of the approach by Schioppa et al is that it uses the Arnoldi + iteration to find the top-k eigenvalues / eigenvectors of the Hessian without + explicitly forming the Hessian. 
In more detail, the approach first runs the + Arnoldi iteration, which only requires the ability to compute Hessian-vector + products, to find a Krylov subspace of moderate dimension, i.e. 200. It then finds + the top-k eigevalues / eigenvectors of the restriction of the Hessian to the + subspace, where k is small, i.e. 50. Finally, it expresses the eigenvectors in + the original basis. This approach for finding the top-k eigenvalues / eigenvectors + is justified by the property of the Arnoldi iteration, that the Krylov subspace + it returns tends to contain the top eigenvectors. + + This implementation require some computation time `__init__`, where it + runs the Arnoldi iteration to calculate `R`. This computation is linear in + `arnoldi_dim` as well as the size of `hessian_dataset`. After that initial + overhead, calculation of influence scores is quick, only requiring a backwards pass + and multiplication, per example. + + Unlike `NaiveInfluenceFunction`, this implementation does not flatten any + parameters, as the 2D Hessian is never formed, and Pytorch's Hessian-vector + implementation (`torch.autograd.functional.hvp`) allows the input and output + vector to be a tuple of tensors. Avoiding flattening / unflattening parameters + brings scalability gains. + """ + + def __init__( + self, + model: Module, + train_dataset: Union[Dataset, DataLoader], + checkpoint: str, + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + hessian_dataset: Optional[Union[Dataset, DataLoader]] = None, + test_loss_fn: Optional[Union[Module, Callable]] = None, + sample_wise_grads_per_batch: bool = False, + projection_dim: int = 50, + seed: int = 0, + arnoldi_dim: int = 200, + arnoldi_tol: float = 1e-1, + hessian_reg: float = 1e-3, + hessian_inverse_tol: float = 1e-4, + projection_on_cpu: bool = True, + show_progress: bool = False, + ) -> None: + """ + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. 
For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. + checkpoint (str): The path to the checkpoint used to compute influence + scores. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (list[str] or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. For now, + we require it to be a "reduction='none'" loss function. For + example, `BCELoss(reduction='none')` would be acceptable, but + `BCELoss(reduction='sum')` would not. + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `train_dataset` and `hessian_dataset`, if they are + of type `Dataset`. `batch_size` should be chosen as large as + possible so that a backwards pass on a batch still fits in memory. + If `train_dataset` and `hessian_dataset`are both of type + `DataLoader`, then `batch_size` is ignored as an argument. + Default: 1 + hessian_dataset (Dataset or Dataloader, optional): The influence score and + self-influence scores this implementation calculates are defined in + terms of the Hessian, i.e. the second-derivative of the model + parameters. This argument provides the dataset used for calculating + the Hessian. It should be smaller than `train_dataset`, which + is the dataset whose examples we want the influence of. If not + provided or none, it will be assumed to be the same as + `train_dataset`. + Default: None + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + If not provided, the loss function for test examples is assumed to + be the same as the loss function for training examples, i.e. + `loss_fn`. + Default: None + sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient + computations w.r.t. model parameters aggregates the results for a + batch and does not allow to access sample-wise gradients w.r.t. + model parameters. This forces us to iterate over each sample in + the batch if we want sample-wise gradients which is computationally + inefficient. We offer an implementation of batch-wise gradient + computations w.r.t. to model parameters which is computationally + more efficient. This implementation can be enabled by setting the + `sample_wise_grad_per_batch` argument to `True`, and should be + enabled if and only if the `loss_fn` argument is a "reduction" loss + function. 
For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. + Default: False + projection_dim (int, optional): This implementation produces a low-rank + approximation of the (inverse) Hessian. This is the rank of that + approximation, and also corresponds to the dimension of the + "influence embeddings" produced by the + `compute_intermediate_quantities` method. + Default: 50 + seed (int, optional): This implementation has a source of randomness - the + initialization basis to the Arnoldi iteration. This seed is used + to make that randomness reproducible. + Default: 42 + arnoldi_dim (int, optional): Calculating the low-rank approximation of the + (inverse) Hessian requires approximating the Hessian's top + eigenvectors / eigenvalues. This is done by first computing a + Krylov subspace via the Arnoldi iteration, and then finding the top + eigenvectors / eigenvalues of the restriction of the Hessian to the + Krylov subspace. Because only the top eigenvectors / eigenvalues + computed in the restriction will be similar to those in the full + space, `arnoldi_dim` should be chosen to be larger than + `projection_dim`. In the paper, they often choose `projection_dim` + to be between 10 and 100, and `arnoldi_dim` to be 200. Please see + the paper as well as Trefethen and Bau, Chapters 33-34 for more + details on the Arnoldi iteration. + Default: 200 + arnoldi_tol (float, optional): After many iterations, the already-obtained + basis vectors may already approximately span the Krylov subspace, + in which case the addition of additional basis vectors involves + normalizing a vector with a small norm. These vectors are not + necessary to include in the basis and furthermore, their small norm + leads to numerical issues. Therefore we stop the Arnoldi iteration + when the addition of additional vectors involves normalizing a + vector with norm below a certain threshold. This argument specifies + that threshold. + Default: 1e-4 + hessian_reg (float, optional): After computing the basis for the Krylov + subspace, the restriction of the Hessian to the subspace may not be + positive definite, which is required, as we compute a low-rank + approximation of its square root via eigen-decomposition. + `hessian_reg` adds an entry to the diagonals of the restriction of + the Hessian to encourage it to be positive definite. This argument + specifies that entry. Note that the regularized Hessian (i.e. with + `hessian_reg` added to its diagonals) does not actually need to be + positive definite - it just needs to have at least 1 positive + eigenvalue. + Default: 1e-3 + hessian_inverse_tol: (float) The tolerance to use when computing the + pseudo-inverse of the (square root of) hessian, restricted to the + Krylov subspace. + Default: 1e-4 + projection_on_cpu (bool, optional): Whether to move the projection, + i.e. low-rank approximation of the inverse Hessian, to cpu, to save + gpu memory. + Default: True + show_progress (bool, optional): In initialization, the Arnoldi iteration + and the subroutine it uses (calculating Hessian-vector products + over batches in `hessian_dataset`) can take a long time. 
If + `show_progress` is true, the progress of both computations + (number of steps in Arnoldi iteration, number of batches processed + in computing Hessian-vector products) will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + InfluenceFunctionBase.__init__( + self, + model, + train_dataset, + checkpoint, + checkpoints_load_func, + layers, + loss_fn, + batch_size, + hessian_dataset, + test_loss_fn, + sample_wise_grads_per_batch, + ) + + self.projection_dim = projection_dim + torch.manual_seed(seed) # for reproducibility + + self.arnoldi_dim = arnoldi_dim + self.arnoldi_tol = arnoldi_tol + self.hessian_reg = hessian_reg + self.hessian_inverse_tol = hessian_inverse_tol + + # infer the device the model is on. all parameters are assumed to be on the + # same device + self.model_device = next(model.parameters()).device + + self.R = self._retrieve_projections_arnoldi_influence_function( + self.hessian_dataloader, + projection_on_cpu, + show_progress, + ) + + def _retrieve_projections_arnoldi_influence_function( + self, + dataloader: DataLoader, + projection_on_cpu: bool, + show_progress: bool, + ) -> List[Tuple[Tensor, ...]]: + """ + + Returns the `R` described in the documentation for + `ArnoldiInfluenceFunction`. The returned `R` represents a set of + parameters in parameter space. However, since this implementation does *not* + flatten parameters, each of those parameters is represented as a tuple of + tensors. Therefore, `R` is represented as a list of tuple of tensors, and + can be viewed as a linear function that takes in a tuple of tensors + (representing a parameter), and returns a vector, where the i-th entry is + the dot-product (as it would be defined over tuple of tensors) of the parameter + (i.e. the input to the linear function) with the i-th entry of `R`. + + Can specify that projection should always be saved on cpu. if so, gradients are + always moved to same device as projections before multiplying (moving + projections to gpu when multiplying would defeat the purpose of moving them to + cpu to save gpu memory). + + Returns: + R (list of tuple of tensors): List of tuple of tensors of length + `projection_dim` (initialization argument). Each element + corresponds to a parameter in parameter-space, is represented as a + tuple of tensors, and together, define a projection that can be + applied to parameters (represented as tuple of tensors). + """ + # create function that computes hessian-vector product, given a vector + # represented as a tuple of tensors + + # first figure out names of params that require gradients. 
this is need to + # create that function, as it replaces params based on their names + params = tuple( + self.model.parameters() + if self.layer_modules is None + else _extract_parameters_from_layers(self.layer_modules) + ) + # the same position in `params` and `param_names` correspond to each other + param_names = _params_to_names(params, self.model) + + # get factory that given a batch, returns a function that given params as + # tuple of tensors, returns loss over the batch + def tensor_tuple_loss_given_batch(batch): + def tensor_tuple_loss(*params): + # `params` is a tuple of tensors, and assumed to be order specified by + # `param_names` + features, labels = tuple(batch[0:-1]), batch[-1] + + _output = _functional_call( + self.model, dict(zip(param_names, params)), features + ) + + # compute the total loss for the batch, adjusting the output of + # `self.loss_fn` based on `self.reduction_type` + return _compute_batch_loss_influence_function_base( + self.loss_fn, _output, labels, self.reduction_type + ) + + return tensor_tuple_loss + + # define function that given batch and vector, returns HVP of loss using the + # batch and vector + def batch_HVP(batch, v): + tensor_tuple_loss = tensor_tuple_loss_given_batch(batch) + return torch.autograd.functional.hvp(tensor_tuple_loss, params, v=v)[1] + + # define function that returns HVP of loss over `dataloader`, given a + # specified vector + def HVP(v): + _hvp = None + + _dataloader = dataloader + if show_progress: + _dataloader = tqdm( + dataloader, desc="processing `hessian_dataset` batch" + ) + + # the HVP of loss using the entire `dataloader` is the sum of the + # per-batch HVP's + return _dataset_fn(_dataloader, batch_HVP, _parameter_add, v) + + for batch in _dataloader: + hvp = batch_HVP(batch, v) + if _hvp is None: + _hvp = hvp + else: + _hvp = _parameter_add(_hvp, hvp) + return _hvp + + # now that can compute the hessian-vector product (of loss over `dataloader`), + # can perform arnoldi iteration + + # we always perform the HVP computations on the device where the model is. + # effectively this means we do the computations on gpu if available. this + # is necessary because the HVP is computationally expensive. + + # get initial random vector, and place it on the same device as the model. + # `_parameter_arnoldi` needs to know which device the model is on, and + # will infer it through the device of this random vector + b = _parameter_to( + tuple(torch.randn_like(param) for param in params), + device=self.model_device, + ) + + # perform the arnoldi iteration, see its documentation for what its return + # values are. note that `H` is *not* the Hessian. + qs, H = _parameter_arnoldi( + HVP, + b, + self.arnoldi_dim, + self.arnoldi_tol, + torch.device("cpu") if projection_on_cpu else self.model_device, + show_progress, + ) + + # `ls`` and `vs`` are (approximately) the top eigenvalues / eigenvectors of the + # matrix used (implicitly) to compute Hessian-vector products by the `HVP` + # input to `_parameter_arnoldi`. this matrix is the Hessian of the loss, + # summed over the examples in `dataloader`. note that because the vectors in + # the Hessian-vector product are actually tuples of tensors representing + # parameters, `vs`` is a list of tuples of tensors. 
note that here, `H` is + # *not* the Hessian (`qs` and `H` together define the Krylov subspace of the + # Hessian) + + ls, vs = _parameter_distill( + qs, H, self.projection_dim, self.hessian_reg, self.hessian_inverse_tol + ) + + # if `vs` were a 2D tensor whose columns contain the top eigenvectors of the + # aforementioned hessian, then `R` would be `vs @ torch.diag(ls ** -0.5)`, i.e. + # scaling each column of `vs` by the corresponding entry in `ls ** -0.5`. + # however, since `vs` is instead a list of tuple of tensors, `R` should be + # a list of tuple of tensors, where each entry in the list is scaled by the + # corresponding entry in `ls ** 0.5`, which we first compute. + ls = (1.0 / ls) ** 0.5 + + # then, scale each entry in `vs` by the corresponding entry in `ls ** 0.5` + # since each entry in `vs` is a tuple of tensors, we use a helper function + # that takes in a tuple of tensors, and a scalar, and multiplies every tensor + # by the scalar. + return [_parameter_multiply(v, l) for (v, l) in zip(vs, ls)] + + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + show_progress: bool = False, + return_on_cpu: bool = True, + test: bool = False, + ) -> Tensor: + r""" + Computes "embedding" vectors for all examples in a single batch, or a + `Dataloader` that yields batches. These embedding vectors are constructed so + that the influence score of a training example on a test example is simply the + dot-product of their corresponding vectors. In both cases, a batch should be + small enough so that a backwards pass for a batch does not lead to + out-of-memory errors. + + In more detail, the embedding vector for an example `x` is + :math`\nabla_\theta L(x)' R`, where :math`R` is as defined in this class' + description. Each element of `R` and :math`\nabla_\theta L(x)` lie in + parameter-space. Therefore, if parameter-space were 1D, so that `R` were + a 2D tensor whose columns are different elements in parameter-space, we would + compute the embeddings for a batch by assembling :math`\nabla_\theta L(x)` for + all examples `x` in the batch as rows in a 2D "batch parameter-gradient" + tensor, and right-multiplying by `R`. However, parameter-space in this + implementation is actually a tuple of tensors. So we do the analogous + computation given this representation of parameter-space. + + If `aggregate` is True, the *sum* of the vectors for all examples is returned, + instead of the vectors for each example. This can be useful for computing the + influence of a given training example on the total loss over a validation + dataset, because due to properties of the dot-product, this influence is the + dot-product of the training example's vector with the sum of the vectors in the + validation dataset. Also, by doing the sum aggregation within this method as + opposed to outside of it (by computing all vectors for the validation dataset, + then taking the sum) allows memory usage to be reduced. + + Args: + inputs_dataset (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, and + and `batch[-1]` are the labels, if any. Here, `model` is model + provided in initialization. This is the same assumption made for + each batch yielded by training dataset `train_dataset`. 
+ aggregate (bool): Whether to return the sum of the vectors for all + examples, as opposed to vectors for each example. + show_progress (bool, optional): Computation of vectors can take a long + time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + vectors have been computed will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. + Default: False + return_on_cpu (bool, optional): Whether to return the vectors on the cpu + (or if not, the gpu). If None, is set to the device that the model + is on. + Default: None + test (bool, optional): Whether to compute the vectors using the loss + function `test_loss_fn` provided in initialization (instead of + `loss_fn`). This argument does not matter if `test_loss_fn` was + not provided, as in this case, `test_loss_fn` and `loss_fn` are the + same. + + Returns: + intermediate_quantities (Tensor): This is a 2D tensor with shape + `(N, projection_dim)`, where `N` is the total number of examples in + `inputs_dataset`, and `projection_dim` was provided in + initialization. Each row contains the vector for a different + example. + """ + # if `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + if show_progress: + inputs_dataset = _progress_bar_constructor( + self, inputs_dataset, "inputs_dataset", "intermediate quantities" + ) + + # infer model / data device through model. return device is same as that of + # model unless explicitly specified + if return_on_cpu is None: + return_device = self.model_device + else: + return_device = torch.device("cpu") if return_on_cpu else self.model_device + + # choose the correct loss function and reduction type based on `test` + loss_fn = self.test_loss_fn if test else self.loss_fn + reduction_type = self.test_reduction_type if test else self.reduction_type + + # define a helper function that returns the embeddings for a batch + def get_batch_embeddings(batch): + # get gradient + features, labels = tuple(batch[0:-1]), batch[-1] + # `jacobians`` is a tensor of tuples. unlike parameters, however, the first + # dimension is a batch dimension + jacobians = _compute_jacobian_sample_wise_grads_per_batch( + self, features, labels, loss_fn, reduction_type + ) + + # `jacobians`` contains the per-example parameters for a batch. this + # function takes in `params`, a tuple of tensors representing a single + # parameter setting, and for each example, computes the dot-product of its + # per-example parameter with `params`. in other words, given `params`, + # representing a basis vector, this function returns the coordinate of + # each example in the batch along that basis. note that `jacobians` and + # `params` are both tuple of tensors, with the same length. however, a + # tensor in `jacobians` always has dimension 1 greater than the + # corresponding tensor in `params`, because the tensors in `jacobians` have + # a batch dimension (the 1st). to do this computation, the naive way would + # be to convert `jacobians` to a list of tuple of tensors, and use + # `_parameter_dot` to take the dot-product of each element in the list + # with `params` to get a 1D tensor whose length is the batch size. however, + # we can do the same computation without actually creating that list of + # tuple of tensors by using broadcasting. 
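            # as a concrete (hypothetical) illustration of the broadcasting: if one
            # parameter is a weight tensor of shape (out_features, in_features), the
            # corresponding entry of `jacobians` has shape
            # (batch_size, out_features, in_features), while the corresponding entry
            # of `params` has shape (out_features, in_features). unsqueezing `params`
            # to add a leading batch dimension, multiplying elementwise, and summing
            # over every non-batch dimension yields a 1D tensor of length
            # `batch_size` - this parameter's contribution to each example's
            # dot-product with `params` - and accumulating these contributions over
            # all parameters gives the batch's coordinates along that basis element.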
+ def get_batch_coordinate(params): + batch_coordinate = 0 + for (_jacobians, param) in zip(jacobians, params): + batch_coordinate += torch.sum( + _jacobians * param.to(device=self.model_device).unsqueeze(0), + dim=tuple(range(1, len(_jacobians.shape))), + ) + return batch_coordinate.to(device=return_device) + + # to get the embedding for the batch, we get the coordinates for the batch + # corresponding to one parameter in `R`. We do this for every parameter in + # `R`, and then concatenate. + return torch.stack( + [get_batch_coordinate(params) for params in self.R], + dim=1, + ) + + # using `get_batch_embeddings` and a helper, return all the vectors or their + # sum, depending on `aggregate` + return _get_dataset_embeddings_intermediate_quantities_influence_function( + get_batch_embeddings, + inputs_dataset, + aggregate, + ) + + @log_usage(skip_self_logging=True) + def influence( # type: ignore[override] + self, + inputs: Tuple, + k: Optional[int] = None, + proponents: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + """ + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method: + + - influence score mode: This mode is used if `k` is None. This mode computes + the influence score of every example in training dataset `train_dataset` + on every example in the test dataset represented by `inputs`. + - k-most influential mode: This mode is used if `k` is not None, and an int. + This mode computes the proponents or opponents of every example in the + test dataset represented by `inputs`. In particular, for each test example in + the test dataset, this mode computes its proponents (resp. opponents), + which are the indices in the training dataset `train_dataset` of the + training examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. Otherwise, + opponents are computed. For each test example, this method also returns the + actual influence score of each proponent (resp. opponent) on the test + example. + + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test dataset. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `train_dataset`, which may + take a long time. If `show_progress` is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). 
Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - influence score mode: if this mode is run (`k` is None), returns a 2D + tensor `influence_scores` of shape `(input_size, train_dataset_size)`, + where `input_size` is the number of examples in the test dataset, and + `train_dataset_size` is the number of examples in training dataset + `train_dataset`. In other words, `influence_scores[i][j]` is the + influence score of the `j`-th example in `train_dataset` on the `i`-th + example in the test dataset. + - k-most influential mode: if this mode is run (`k` is an int), returns + a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of + shape `(input_size, k)`, where `input_size` is the number of examples in + the test dataset. If computing proponents (resp. opponents), + `indices[i][j]` is the index in training dataset `train_dataset` of the + example with the `j`-th highest (resp. lowest) influence score (out of + the examples in `train_dataset`) on the `i`-th example in the test + dataset. `influence_scores` contains the corresponding influence scores. + In particular, `influence_scores[i][j]` is the influence score of example + `indices[i][j]` in `train_dataset` on example `i` in the test dataset + represented by `inputs`. + """ + return _influence_route_to_helpers( + self, + inputs, + k, + proponents, + show_progress=show_progress, + ) + + def _get_k_most_influential( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `train_dataset`, If `show_progress` is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `train_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. 
Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `train_dataset` + on example `i` in the test dataset represented by `inputs`. + """ + desc = ( + None + if not show_progress + else ( + ( + f"Using {self.get_name()} to perform computation for " + f'getting {"proponents" if proponents else "opponents"}. ' + "Processing training batches" + ) + ) + ) + return KMostInfluentialResults( + *_get_k_most_influential_helper( + self.train_dataloader, + functools.partial( + _influence_batch_intermediate_quantities_influence_function, self + ), + inputs, + k, + proponents, + show_progress, + desc, + ) + ) + + def _influence( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): To compute the influence of examples in + training dataset `train_dataset`, we compute the influence + of each batch. If `show_progress` is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + influence_scores (Tensor): Influence scores over the entire + training dataset `train_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th example in the test dataset. + """ + # turn inputs and targets into a dataset. inputs has already been processed + # so that it should always be unpacked + inputs_dataset = _format_inputs_dataset(inputs) + return _influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) + + def self_influence( + self, + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, + show_progress: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Implementation-wise, the self-influence score for an example is simply the + squared norm of the example's "embedding" vector. Therefore, the implementation + leverages `compute_intermediate_quantities`. 
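        # A minimal sketch of the relationship described above. `inf_fn` is assumed
        # to be an already-constructed `ArnoldiInfluenceFunction` and `batch` a
        # tuple whose last element holds the labels; both names are placeholders
        # used only for illustration.
        #
        #     embeddings = inf_fn.compute_intermediate_quantities(batch)
        #     # squared norm of each row, i.e. the per-example self influence score
        #     self_scores = (embeddings ** 2).sum(dim=1)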
+ + Args: + inputs_dataset (tuple or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. + Default: None + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + self influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + return _self_influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) diff --git a/captum/influence/_core/influence.py b/captum/influence/_core/influence.py index 553ab38abb..51b33d0a9c 100644 --- a/captum/influence/_core/influence.py +++ b/captum/influence/_core/influence.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Type from torch.nn import Module from torch.utils.data import Dataset @@ -42,3 +42,15 @@ def influence(self, inputs: Any = None, **kwargs: Any) -> Any: though this may change in the future. """ pass + + @classmethod + def get_name(cls: Type["DataInfluence"]) -> str: + r""" + Create readable class name. Due to the nature of the names of `TracInCPBase` + subclasses, simply returns the class name. For example, for a class called + TracInCP, we return the string TracInCP. + + Returns: + name (str): a readable class name + """ + return cls.__name__ diff --git a/captum/influence/_core/influence_function.py b/captum/influence/_core/influence_function.py new file mode 100644 index 0000000000..6e3540f1e4 --- /dev/null +++ b/captum/influence/_core/influence_function.py @@ -0,0 +1,1314 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
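# Usage note for the `get_name` classmethod added to `DataInfluence` above: it simply
# returns `cls.__name__`, so any subclass reports its own class name. A small sketch
# (the concrete subclass chosen here is only an example):
#
#     from captum.influence import TracInCP
#     TracInCP.get_name()  # -> "TracInCP"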
+ +import functools +from abc import abstractmethod +from operator import add +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from captum._utils.gradient import _extract_parameters_from_layers +from captum.influence._core.influence import DataInfluence + +from captum.influence._utils.common import ( + _check_loss_fn, + _compute_batch_loss_influence_function_base, + _compute_jacobian_sample_wise_grads_per_batch, + _dataset_fn, + _flatten_params, + _format_inputs_dataset, + _functional_call, + _get_k_most_influential_helper, + _influence_batch_intermediate_quantities_influence_function, + _influence_helper_intermediate_quantities_influence_function, + _influence_route_to_helpers, + _load_flexible_state_dict, + _params_to_names, + _progress_bar_constructor, + _self_influence_helper_intermediate_quantities_influence_function, + _set_active_parameters, + _top_eigen, + _unflatten_params_factory, + KMostInfluentialResults, +) +from captum.log import log_usage +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + + +class InfluenceFunctionBase(DataInfluence): + r""" + `InfluenceFunctionBase` is a base class for implementations which compute the + influence score as defined in the paper "Understanding Black-box Predictions via + Influence Functions" (https://arxiv.org/pdf/1703.04730.pdf). This "infinitesimal" + influence score approximately answers the question if a given training example + were infinitesimally down-weighted and the model re-trained to optimality, how much + would the loss on a given test example change. Mathematically, the aforementioned + influence score is given by :math`\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, + where :math`\nabla_\theta L(x)` is the gradient of the loss, considering only + training example :math`x` with respect to (a subset of) model parameters + :math`\theta`, :math`\nabla_\theta L(z)` is the analogous quantity for a test + example :math`z`, and :math`H` is the Hessian of the (subset of) model parameters + at a given model checkpoint. "Subset of model parameters" refers to the parameters + specified by the `layers` initialization argument; for computational purposes, + we may only consider the gradients / Hessian involving parameters in a subset of + the model's layers. This is a commonly-taken approach in the research literature. + + There can be multiple implementations of this class, because although the paper + defines a particular "infinitesimal" kind of influence score, there can be multiple + ways to compute it, each with different levels of accuracy / scalability. + """ + + def __init__( + self, + model: Module, + train_dataset: Union[Dataset, DataLoader], + checkpoint: str, + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + hessian_dataset: Optional[Union[Dataset, DataLoader]] = None, + test_loss_fn: Optional[Union[Module, Callable]] = None, + sample_wise_grads_per_batch: bool = False, + ) -> None: + """ + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. 
+ train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. + checkpoint (str): The path to the checkpoint used to compute influence + scores. + checkpoints_load_func (Callable, optional): The function to load a saved + checkpoint into a model to update its parameters, and get the + learning rate if it is saved. By default uses a utility to load a + model saved as a state dict. + Default: _load_flexible_state_dict + layers (list[str] or None, optional): A list of layer names for which + gradients should be computed. If `layers` is None, gradients will + be computed for all layers. Otherwise, they will only be computed + for the layers specified in `layers`. + Default: None + loss_fn (Callable, optional): The loss function applied to model. There + are two options for the return type of `loss_fn`. First, `loss_fn` + can be a "per-example" loss function - returns a 1D Tensor of + losses for each example in a batch. `nn.BCELoss(reduction="none")` + would be an "per-example" loss function. Second, `loss_fn` can be + a "reduction" loss function that reduces the per-example losses, + in a batch, and returns a single scalar Tensor. For this option, + the reduction must be the *sum* or the *mean* of the per-example + losses. For instance, `nn.BCELoss(reduction="sum")` is acceptable. + Note for the first option, the `sample_wise_grads_per_batch` + argument must be False, and for the second option, + `sample_wise_grads_per_batch` must be True. Also note that for + the second option, if `loss_fn` has no "reduction" attribute, + the implementation assumes that the reduction is the *sum* of the + per-example losses. If this is not the case, i.e. the reduction + is the *mean*, please set the "reduction" attribute of `loss_fn` + to "mean", i.e. `loss_fn.reduction = "mean"`. + batch_size (int or None, optional): Batch size of the DataLoader created to + iterate through `train_dataset` and `hessian_dataset`, if they are + of type `Dataset`. 
`batch_size` should be chosen as large as + possible so that a backwards pass on a batch still fits in memory. + If `train_dataset` and `hessian_dataset`are both of type + `DataLoader`, then `batch_size` is ignored as an argument. + Default: 1 + hessian_dataset (Dataset or Dataloader, optional): The influence score and + self-influence scores this implementation calculates are defined in + terms of the Hessian, i.e. the second-derivative of the model + parameters. This argument provides the dataset used for calculating + the Hessian. It should be smaller than `train_dataset`, which + is the dataset whose examples we want the influence of. If not + provided or none, it will be assumed to be the same as + `train_dataset`. + Default: None + test_loss_fn (Callable, optional): In some cases, one may want to use a + separate loss functions for training examples, i.e. those in + `train_dataset`, and for test examples, i.e. those + represented by the `inputs` and `targets` arguments to the + `influence` method. For example, if one wants to calculate the + influence score of a training example on a test example's + prediction for a fixed class, `test_loss_fn` could map from the + logits for all classes to the logits for a fixed class. + `test_loss_fn` needs satisfy the same constraints as `loss_fn`. + Thus, the same checks that we apply to `loss_fn` are also applied + to `test_loss_fn`, if the latter is provided. Note that the + constraints on both `loss_fn` and `test_loss_fn` both depend on + `sample_wise_grads_per_batch`. This means `loss_fn` and + `test_loss_fn` must either both be "per-example" loss functions, + or both be "reduction" loss functions. If not provided, the loss + function for test examples is assumed to be the same as the loss + function for training examples, i.e. `loss_fn`. + Default: None + sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient + computations w.r.t. model parameters aggregates the results for a + batch and does not allow to access sample-wise gradients w.r.t. + model parameters. This forces us to iterate over each sample in + the batch if we want sample-wise gradients which is computationally + inefficient. We offer an implementation of batch-wise gradient + computations w.r.t. to model parameters which is computationally + more efficient. This implementation can be enabled by setting the + `sample_wise_grad_per_batch` argument to `True`, and should be + enabled if and only if the `loss_fn` argument is a "reduction" loss + function. For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. 
+ Default: False + """ + + self.model = model + + self.checkpoint = checkpoint + + self.checkpoints_load_func = checkpoints_load_func + # actually load the checkpoint + checkpoints_load_func(model, checkpoint) + self.loss_fn = loss_fn + # If test_loss_fn not provided, it's assumed to be same as loss_fn + self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn + self.sample_wise_grads_per_batch = sample_wise_grads_per_batch + self.batch_size = batch_size + + if not isinstance(train_dataset, DataLoader): + assert isinstance(batch_size, int), ( + "since the `train_dataset` argument was a `Dataset`, " + "`batch_size` must be an int." + ) + self.train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False) + else: + self.train_dataloader = train_dataset + + if hessian_dataset is None: + self.hessian_dataloader = self.train_dataloader + elif not isinstance(hessian_dataset, DataLoader): + assert isinstance(batch_size, int), ( + "since the `shared_dataset` argument was a `Dataset`, " + "`batch_size` must be an int." + ) + self.hessian_dataloader = DataLoader( + hessian_dataset, batch_size, shuffle=False + ) + else: + self.hessian_dataloader = hessian_dataset + + # we check the loss functions in `InfluenceFunctionBase` rather than + # individually in its child classes because we assume all its implementations + # have the same requirements on loss functions, i.e. the type of reductions + # supported. furthermore, these checks are done using a helper function that + # handles all implementations with a `sample_wise_grads_per_batch` + # initialization argument. + + # we save the reduction type for both `loss_fn` and `test_loss_fn` because + # 1) if `sample_wise_grads_per_batch` is true, the reduction type is needed + # to compute per-example gradients, and 2) regardless, reduction type for + # `loss_fn` is needed to compute the Hessian. + + # check `loss_fn` + self.reduction_type = _check_loss_fn( + self, loss_fn, "loss_fn", sample_wise_grads_per_batch + ) + # check `test_loss_fn` if it was provided + if not (test_loss_fn is None): + self.test_reduction_type = _check_loss_fn( + self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch + ) + else: + self.test_reduction_type = self.reduction_type + + self.layer_modules = None + if not (layers is None): + self.layer_modules = _set_active_parameters(model, layers) + + @abstractmethod + def self_influence( + self, + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, + show_progress: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Args: + inputs_dataset (tuple or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. 
That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress` is true, the progress of this computation will be + displayed. In more detail, this computation will iterate over all + checkpoints (provided as the `checkpoints` initialization argument) + in an outer loop, and iterate over all batches that + `inputs_dataset` represents in an inner loop. Therefore, the + total number of (checkpoint, batch) combinations that need to be + iterated over is + (# of checkpoints x # of batches that `inputs_dataset` represents). + If `show_progress` is True, the total progress of both the outer + iteration over checkpoints and the inner iteration over batches is + displayed. It will try to use tqdm if available for advanced + features (e.g. time estimation). Otherwise, it will fallback to a + simple output of progress. + Default: False + + Returns: + self_influence_scores (Tensor): This is a 1D tensor containing the self + influence scores of all examples in `inputs_dataset`, regardless of + whether it represents a single batch or a `DataLoader` that yields + batches. + """ + pass + + @abstractmethod + def _get_k_most_influential( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `train_dataset`, If `show_progress` is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `train_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. 
In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `train_dataset` + on example `i` in the test dataset represented by `inputs`. + """ + pass + + @abstractmethod + def _influence( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + + inputs (tuple[Any, ...]): A batch of examples. Does not represent labels, + which are passed as `targets`. The assumption is that + `model(*inputs)` produces the predictions for the batch. + targets (Tensor, optional): If computing influence scores on a loss + function, these are the labels corresponding to the batch + `inputs`. + Default: None + + Returns: + influence_scores (Tensor): Influence scores over the entire + training dataset `train_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th input example. + """ + pass + + @abstractmethod + def influence( # type: ignore[override] + self, + inputs: Tuple, + k: Optional[int] = None, + proponents: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + r""" + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method: + + - influence score mode: This mode is used if `k` is None. This mode computes + the influence score of every example in training dataset `train_dataset` + on every example in the test dataset represented by `inputs`. + - k-most influential mode: This mode is used if `k` is not None, and an int. + This mode computes the proponents or opponents of every example in the + test dataset represented by `inputs`. In particular, for each test example in + the test dataset, this mode computes its proponents (resp. opponents), + which are the indices in the training dataset `train_dataset` of the + training examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. Otherwise, + opponents are computed. For each test example, this method also returns the + actual influence score of each proponent (resp. opponent) on the test + example. + + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test dataset. + Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `train_dataset`, which may + take a long time. If `show_progress` is true, the progress of + "training dataset computations" will be displayed. 
In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - influence score mode: if this mode is run (`k` is None), returns a 2D + tensor `influence_scores` of shape `(input_size, train_dataset_size)`, + where `input_size` is the number of examples in the test dataset, and + `train_dataset_size` is the number of examples in training dataset + `train_dataset`. In other words, `influence_scores[i][j]` is the + influence score of the `j`-th example in `train_dataset` on the `i`-th + example in the test dataset. + - k-most influential mode: if this mode is run (`k` is an int), returns + a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of + shape `(input_size, k)`, where `input_size` is the number of examples in + the test dataset. If computing proponents (resp. opponents), + `indices[i][j]` is the index in training dataset `train_dataset` of the + example with the `j`-th highest (resp. lowest) influence score (out of + the examples in `train_dataset`) on the `i`-th example in the test + dataset. `influence_scores` contains the corresponding influence scores. + In particular, `influence_scores[i][j]` is the influence score of example + `indices[i][j]` in `train_dataset` on example `i` in the test dataset + represented by `inputs`. + """ + pass + + +class IntermediateQuantitiesInfluenceFunction(InfluenceFunctionBase): + """ + Implementations of this class all implement the `compute_intermediate_quantities` + method, which computes the "embedding" vectors for all examples in a test dataset. + These embedding vectors are assumed to have the following properties: + - the influence score of one example on another example, as calculated by the + implementation, is the dot-product of their respective embeddings. + - the self influence score of an example is the squared norm of its embedding. + """ + + @abstractmethod + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + show_progress: bool = False, + return_on_cpu: bool = True, + test: bool = False, + ): + pass + + +def _flatten_forward_factory( + model: nn.Module, + loss_fn: Optional[Union[Module, Callable]], + reduction_type: str, + unflatten_fn: Callable, + param_names: List[str], +): + """ + Given a model, loss function, reduction type of the loss, function that unflattens + 1D tensor input into a tuple of tensors, the name of each tensor in that tuple, + each of which represents a parameter of `model`, and returns a factory. The factory + accepts a batch, and returns a function whose input is the parameters represented + by `param_names`, and output is the total loss of the model with those parameters, + calculated on the batch. The parameter input to the returned function is assumed to + be *flattened* via the inverse of `unflatten_fn`, which takes a tuple of tensors to + a 1D tensor. This returned function, accepting a single flattened 1D parameter, is + useful for computing the parameter gradient involving the batch as a 1D tensor, and + the Hessian involving the batch as a 2D tensor. Both quantities are needed to + calculate the kind of influence scores returned by implementations of + `InfluenceFunctionBase`. 
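A rough usage sketch (the variables below are placeholders, not actual module
attributes; `_compute_dataset_func` below shows the real usage):

    make_forward = _flatten_forward_factory(
        model, loss_fn, "sum", unflatten_fn, param_names
    )
    flattened_forward = make_forward(batch)
    # total loss of `model` on `batch`, as a function of the flattened parameters
    loss = flattened_forward(flattened_params)
    # because the input is a single 1D tensor, the resulting Hessian is a 2D tensor
    hessian = torch.autograd.functional.hessian(flattened_forward, flattened_params)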
+ """ + # this is the factory that accepts a batch + def flatten_forward_factory_given_batch(batch): + + # this is the function that factory returns, which is a function of flattened + # parameters + def flattened_forward(flattened_params): + # as everywhere else, the all but the last elements of a batch are + # assumed to correspond to the features, i.e. input to forward function + features, labels = tuple(batch[0:-1]), batch[-1] + + _output = _functional_call( + model, dict(zip(param_names, unflatten_fn(flattened_params))), features + ) + + # compute the total loss for the batch, adjusting the output of + # `loss_fn` based on `reduction_type` + return _compute_batch_loss_influence_function_base( + loss_fn, _output, labels, reduction_type + ) + + return flattened_forward + + return flatten_forward_factory_given_batch + + +def _compute_dataset_func( + inputs_dataset: Union[Tuple[Tensor, ...], DataLoader], + model: Module, + loss_fn: Optional[Union[Module, Callable]], + reduction_type: str, + layer_modules: Optional[List[Module]], + f: Callable, + show_progress: bool, + **f_kwargs, +): + """ + This function is used to compute higher-order functions of a given model's loss + over a given dataset, using the model's current parameters. For example, that + higher-order function `f` could be the Hessian, or a Hessian-vector product. + This function uses the factory returned by `_flatten_forward_factory`, which given + a batch, returns the loss for the batch as a function of flattened parameters. + In particular, for each batch in `inputs_dataset`, this function uses the factory + to obtain `flattened_forward`, which returns the loss for `model`, using the batch. + `flattened_forward`, as well as the flattened parameters for `model`, are used by + argument `f`, a higher-order function, to compute a batch-specific quantity. + For example, `f` could compute the Hessian via `torch.autograd.functional.hessian`, + or compute a Hessian-vector product via `torch.autograd.functional.hvp`. Additional + arguments besides `flattened_forward` and the flattened parameters, i.e. the vector + in Hessian-vector products, can be passed via named arguments. + """ + # extract the parameters in a tuple + params = tuple( + model.parameters() + if layer_modules is None + else _extract_parameters_from_layers(layer_modules) + ) + + # construct functions that can flatten / unflatten tensors, and get + # names of each param in `params`. 
+ # Both are needed for calling `_flatten_forward_factory`
+ _unflatten_params = _unflatten_params_factory(
+ tuple([param.shape for param in params])
+ )
+ param_names = _params_to_names(params, model)
+
+ # prepare factory
+ factory_given_batch = _flatten_forward_factory(
+ model,
+ loss_fn,
+ reduction_type,
+ _unflatten_params,
+ param_names,
+ )
+
+ # the function returned by the factory is evaluated at a *flattened* version of
+ # params, so need to create that
+ flattened_params = _flatten_params(params)
+
+ # define function of a single batch
+ def batch_f(batch):
+ flattened_forward = factory_given_batch(batch)  # accepts flattened params
+ return f(flattened_forward, flattened_params, **f_kwargs)
+
+ # sum up results of `batch_f`
+ if show_progress:
+ inputs_dataset = tqdm(inputs_dataset, desc="processing `hessian_dataset` batch")
+
+ return _dataset_fn(inputs_dataset, batch_f, add)
+
+
+ def _get_dataset_embeddings_intermediate_quantities_influence_function(
+ batch_embeddings_fn: Callable,
+ inputs_dataset: DataLoader,
+ aggregate: bool,
+ ):
+ """
+ given `batch_embeddings_fn`, which produces the embeddings for a given batch,
+ returns either the embeddings for an entire dataset (if `aggregate` is false),
+ or the sum of the embeddings for an entire dataset (if `aggregate` is true).
+ """
+ # if aggregate is false, we concatenate the embeddings for all batches
+ if not aggregate:
+ return torch.cat(
+ [batch_embeddings_fn(batch) for batch in inputs_dataset], dim=0
+ )
+ else:
+ # if aggregate is True, we return the sum of all embeddings for all
+ # batches. we do this by summing within each batch, and then summing across
+ # batches.
+ inputs_dataset_iter = iter(inputs_dataset)
+
+ batch = next(inputs_dataset_iter)
+ total_embedding = torch.sum(batch_embeddings_fn(batch), dim=0)
+
+ for batch in inputs_dataset_iter:
+ total_embedding += torch.sum(batch_embeddings_fn(batch), dim=0)
+
+ # we unsqueeze because regardless of aggregate, the returned tensor should
+ # be 2D.
+ return total_embedding.unsqueeze(0)
+
+
+ class NaiveInfluenceFunction(IntermediateQuantitiesInfluenceFunction):
+ r"""
+ This is a computationally inefficient implementation that computes the type of
+ "infinitesimal" influence scores defined in the paper "Understanding Black-box
+ Predictions via Influence Functions" by Koh et al.
+ (https://arxiv.org/pdf/1703.04730.pdf). The computational bottleneck in computing
+ infinitesimal influence scores is computing inverse Hessian-vector products, as can
+ be seen from its definition in `InfluenceFunctionBase`. This implementation is
+ inefficient / naive in that it explicitly forms the Hessian :math`H`, unlike other
+ implementations which compute inverse Hessian-vector products without explicitly
+ forming the Hessian. The purpose of this implementation is to have a way to
+ generate the "ground-truth" influence scores, to which other implementations,
+ which are more efficient but return only approximations of the influence score, can
+ be compared.
+
+ This implementation computes a low-rank approximation of the inverse Hessian, i.e.
+ a tall and skinny (with width k) matrix :math`R` such that
+ :math`H^{-1} \approx RR'`, where k is small. In particular, let :math`L` be the
+ matrix of width k whose columns contain the top-k eigenvectors of :math`H`, and let
+ :math`V` be the k by k matrix whose diagonals contain the corresponding eigenvalues.
+ This implementation lets :math`R=LV^{-1/2}`, so that
+ :math`RR'=LV^{-1}L'` approximates :math`H^{-1}`.
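As a minimal sketch of that construction (toy code which mirrors, but is not, the
logic in `_retrieve_projections_naive_influence_function`; `H` denotes the
already-formed `D x D` Hessian and `k` the rank of the approximation):

    evals, evecs = torch.linalg.eigh(H)        # eigenvalues in ascending order
    evals, evecs = evals[-k:], evecs[:, -k:]   # keep the top-k eigenpairs
    R = evecs * evals ** -0.5                  # `D x k`; `R @ R.T` approximates the inverse of `H`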
Thus, the core computational step is + computing the top-k eigenvalues / eigenvectors. + + This low-rank approximation is useful for several reasons: + - It avoids numerical issues associated with inverting small eigenvalues. + - Since the influence score is given by + :math`\nabla_\theta L(x)' H^{-1} \nabla_\theta L(z)`, which is approximated by + :math`(\nabla_\theta L(x)' R) (\nabla_\theta L(z)' R)`, we can compute an + "influence embedding" for a given example :math`x`, :math`\nabla_\theta L(x)' R`, + such that the influence score of one example on another is approximately the + dot-product of their respective embeddings. + + This implementation is "naive" in that it computes the top-k eigenvalues / + eigenvectors by explicitly forming the Hessian, converting it to a 2D tensor, + computing its eigenvectors / eigenvalues, and then sorting. See documentation of the + `_retrieve_projections_naive_influence_function` method for more details. + """ + + def __init__( + self, + model: Module, + train_dataset: Union[Dataset, DataLoader], + checkpoint: str, + checkpoints_load_func: Callable = _load_flexible_state_dict, + layers: Optional[List[str]] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + batch_size: Union[int, None] = 1, + hessian_dataset: Optional[Union[Dataset, DataLoader]] = None, + test_loss_fn: Optional[Union[Module, Callable]] = None, + sample_wise_grads_per_batch: bool = False, + projection_dim: int = 50, + seed: int = 42, + hessian_reg: float = 1e-6, + hessian_inverse_tol: float = 1e-5, + projection_on_cpu: bool = True, + show_progress: bool = False, + ) -> None: + """ + Args: + model (torch.nn.Module): An instance of pytorch model. This model should + define all of its layers as attributes of the model. + train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader): + In the `influence` method, we either compute the influence score of + training examples on examples in a test batch, or self influence + scores for those training examples, depending on which mode is used. + This argument represents the training dataset containing those + training examples. In order to compute those influence scores, we + will create a Pytorch DataLoader yielding batches of training + examples that is then used for processing. If this argument is + already a Pytorch Dataloader, that DataLoader can be directly + used for processing. If it is instead a Pytorch Dataset, we will + create a DataLoader using it, with batch size specified by + `batch_size`. For efficiency purposes, the batch size of the + DataLoader used for processing should be as large as possible, but + not too large, so that certain intermediate quantities created + from a batch still fit in memory. Therefore, if + `train_dataset` is a Dataset, `batch_size` should be large. + If `train_dataset` was already a DataLoader to begin with, + it should have been constructed to have a large batch size. It is + assumed that the Dataloader (regardless of whether it is created + from a Pytorch Dataset or not) yields tuples. For a `batch` that is + yielded, of length `L`, it is assumed that the forward function of + `model` accepts `L-1` arguments, and the last element of `batch` is + the label. In other words, `model(*batch[:-1])` gives the output of + `model`, and `batch[-1]` are the labels for the batch. + checkpoint (str): The path to the checkpoint used to compute influence + scores. 
+ checkpoints_load_func (Callable, optional): The function to load a saved
+ checkpoint into a model to update its parameters, and get the
+ learning rate if it is saved. By default uses a utility to load a
+ model saved as a state dict.
+ Default: _load_flexible_state_dict
+ layers (list[str] or None, optional): A list of layer names for which
+ gradients should be computed. If `layers` is None, gradients will
+ be computed for all layers. Otherwise, they will only be computed
+ for the layers specified in `layers`.
+ Default: None
+ loss_fn (Callable, optional): The loss function applied to model. For now,
+ we require it to be a "reduction='none'" loss function. For
+ example, `BCELoss(reduction='none')` would be acceptable, but
+ `BCELoss(reduction='sum')` would not.
+ batch_size (int or None, optional): Batch size of the DataLoader created to
+ iterate through `train_dataset` and `hessian_dataset`, if they are
+ of type `Dataset`. `batch_size` should be chosen as large as
+ possible so that a backwards pass on a batch still fits in memory.
+ If `train_dataset` and `hessian_dataset` are both of type
+ `DataLoader`, then `batch_size` is ignored as an argument.
+ Default: 1
+ hessian_dataset (Dataset or DataLoader, optional): The influence score and
+ self-influence scores this implementation calculates are defined in
+ terms of the Hessian, i.e. the second-derivative of the model
+ parameters. This argument provides the dataset used for calculating
+ the Hessian. It should be smaller than `train_dataset`, which
+ is the dataset whose examples we want the influence of. If not
+ provided or None, it will be assumed to be the same as
+ `train_dataset`.
+ Default: None
+ test_loss_fn (Callable, optional): In some cases, one may want to use a
+ separate loss function for training examples, i.e. those in
+ `train_dataset`, and for test examples, i.e. those
+ represented by the `inputs` and `targets` arguments to the
+ `influence` method. For example, if one wants to calculate the
+ influence score of a training example on a test example's
+ prediction for a fixed class, `test_loss_fn` could map from the
+ logits for all classes to the logits for a fixed class.
+ `test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
+ Thus, the same checks that we apply to `loss_fn` are also applied
+ to `test_loss_fn`, if the latter is provided. Note that the
+ constraints on `loss_fn` and `test_loss_fn` both depend on
+ `sample_wise_grads_per_batch`. This means `loss_fn` and
+ `test_loss_fn` must either both be "per-example" loss functions,
+ or both be "reduction" loss functions. If not provided, the loss
+ function for test examples is assumed to be the same as the loss
+ function for training examples, i.e. `loss_fn`.
+ Default: None
+ sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
+ computations w.r.t. model parameters aggregate the results for a
+ batch and do not allow access to sample-wise gradients w.r.t.
+ model parameters. This forces us to iterate over each sample in
+ the batch if we want sample-wise gradients, which is computationally
+ inefficient. We offer an implementation of batch-wise gradient
+ computations w.r.t. model parameters which is computationally
+ more efficient. This implementation can be enabled by setting the
+ `sample_wise_grads_per_batch` argument to `True`, and should be
+ enabled if and only if the `loss_fn` argument is a "reduction" loss
+ function.
For example, `nn.BCELoss(reduction="sum")` would be a + valid `loss_fn` if this implementation is enabled (see + documentation for `loss_fn` for more details). Note that our + current implementation enables batch-wise gradient computations + only for a limited number of PyTorch nn.Modules: Conv2D and Linear. + This list will be expanded in the near future. Therefore, please + do not enable this implementation if gradients will be computed + for other kinds of layers. + Default: False + projection_dim (int, optional): This implementation produces a low-rank + approximation of the (inverse) Hessian. This is the rank of that + approximation, and also corresponds to the dimension of the + "influence embeddings" produced by the + `compute_intermediate_quantities` method. + Default: 50 + seed (int, optional): This implementation has a source of randomness - the + initialization basis to the Arnoldi iteration. This seed is used + to make that randomness reproducible. + Default: 42 + hessian_reg (float, optional): We add an entry to the hessian's diagonal + entries before computing its eigenvalues / eigenvectors. + This is that entry. + Default: 1e-6 + hessian_inverse_tol: (float) The tolerance to use when computing the + pseudo-inverse of the (square root of) hessian. + Default: 1e-6 + projection_on_cpu (bool, optional): Whether to move the projection, + i.e. low-rank approximation of the inverse Hessian, to cpu, to save + gpu memory. + Default: True + show_progress (bool, optional): This implementation explicitly computes the + Hessian over batches in `hessian_dataloader` (and sums them) which + can take a long time. If `show_progress` is true, the number of + batches for which the Hessian has been computed will be displayed. + It will try to use tqdm if available for advanced features (e.g. + time estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + InfluenceFunctionBase.__init__( + self, + model, + train_dataset, + checkpoint, + checkpoints_load_func, + layers, + loss_fn, + batch_size, + hessian_dataset, + test_loss_fn, + sample_wise_grads_per_batch, + ) + + self.projection_dim = projection_dim + torch.manual_seed(seed) # for reproducibility + + self.hessian_reg = hessian_reg + self.hessian_inverse_tol = hessian_inverse_tol + + # infer the device the model is on. all parameters are assumed to be on the + # same device + self.model_device = next(model.parameters()).device + + self.R = self._retrieve_projections_naive_influence_function( + self.hessian_dataloader, + projection_on_cpu, + show_progress, + ) + + def _retrieve_projections_naive_influence_function( + self, + dataloader: DataLoader, + projection_on_cpu: bool, + show_progress: bool, + ) -> Tensor: + r""" + Returns the matrix `R` described in the documentation for + `NaiveInfluenceFunction`. In short, `R` has the property that + :math`H^{-1} \approx RR'`, where `H` is the Hessian. Since this is a "naive" + implementation, it does so by explicitly forming the Hessian, converting + it to a 2D tensor, and computing its eigenvectors / eigenvalues, before + filtering out some eigenvalues and then inverting them. The returned matrix + `R` represents a set of parameters in parameter space. Since the Hessian + is obtained by first flattening the parameters, each column of `R` corresponds + to a *flattened* parameter in parameter space. + + Args: + dataloader (DataLoader): The returned matrix `R` gives a low-rank + approximation of the Hessian `H`. 
This dataloader defines the + dataset used to compute the Hessian that is being approximated. + projection_on_cpu (bool, optional): Whether to move the projection, + i.e. low-rank approximation of the inverse Hessian, to cpu, to save + gpu memory. + show_progress (bool): Computing the Hessian that is being approximated + requires summing up the Hessians computed using different batches, + which may take a long time. If `show_progress` is true, the number + of batches that have been processed will be displayed. It will try + to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + + Returns: + R (Tensor): Tall and skinny tensor with width `projection_dim` + (initialization argument). Each column corresponds to a flattened + parameter in parameter-space. `R` has the property that + :math`H^{-1} \approx RR'`. + """ + # compute the hessian using the dataloader. hessian is always computed using + # the training loss function. H is 2D, with each column / row corresponding to + # a different parameter. we cannot directly use + # `torch.autograd.functional.hessian`, because it does not return a 2D tensor. + # instead, to compute H, we first create a function that accepts *flattened* + # model parameters (i.e. a 1D tensor), and outputs the loss of `self.model`, + # using those parameters, aggregated over `dataloader`. this function is then + # passed to `torch.autograd.functional.hessian`. because its input is 1D, the + # resulting hessian is 2D, as desired. all this functionality is handled by + # `_compute_dataset_func`. + H = _compute_dataset_func( + dataloader, + self.model, + self.loss_fn, + self.reduction_type, + self.layer_modules, + torch.autograd.functional.hessian, + show_progress, + ) + + # H is approximately `vs @ torch.diag(ls) @ vs.T``, using eigendecomposition + ls, vs = _top_eigen( + H, self.projection_dim, self.hessian_reg, self.hessian_inverse_tol + ) + + # if no positive eigenvalues exist, we cannot compute a low-rank + # approximation of the square root of the hessian H, so raise exception + if len(ls) == 0: + raise Exception( + "Hessian has no positive " + "eigenvalues, so cannot take its square root." + ) + + # `R` is `vs @ torch.diag(ls ** -0.5)`, since H^{-1} is approximately + # `vs @ torch.diag(ls ** -1) @ vs.T` + # see https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix#Matrix_inverse_via_eigendecomposition # noqa: E501 + # for details, which mentions that discarding small eigenvalues (as done in + # `_top_eigen`) reduces noisiness of the inverse. + ls = (1.0 / ls) ** 0.5 + return (ls.unsqueeze(0) * vs).to( + device=torch.device("cpu") if projection_on_cpu else self.model_device + ) + + def compute_intermediate_quantities( + self, + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + aggregate: bool = False, + show_progress: bool = False, + return_on_cpu: bool = True, + test: bool = False, + ) -> Tensor: + r""" + Computes "embedding" vectors for all examples in a single batch, or a + `Dataloader` that yields batches. These embedding vectors are constructed so + that the influence score of a training example on a test example is simply the + dot-product of their corresponding vectors. In both cases, a batch should be + small enough so that a backwards pass for a batch does not lead to + out-of-memory errors. + + In more detail, the embedding vector for an example `x` is + :math`\nabla_\theta L(x)' R`, where :math`R` is as defined in this class' + description. 
The embeddings for a batch of examples are computed by assembling + :math`\nabla_\theta L(x)` for all examples `x` in the batch as rows in a 2D + tensor, and right-multiplying by `R`. + + If `aggregate` is True, the *sum* of the vectors for all examples is returned, + instead of the vectors for each example. This can be useful for computing the + influence of a given training example on the total loss over a validation + dataset, because due to properties of the dot-product, this influence is the + dot-product of the training example's vector with the sum of the vectors in the + validation dataset. Also, by doing the sum aggregation within this method as + opposed to outside of it (by computing all vectors for the validation dataset, + then taking the sum) allows memory usage to be reduced. + + Args: + inputs_dataset (Tuple, or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, and + and `batch[-1]` are the labels, if any. Here, `model` is model + provided in initialization. This is the same assumption made for + each batch yielded by training dataset `train_dataset`. + aggregate (bool): Whether to return the sum of the vectors for all + examples, as opposed to vectors for each example. + show_progress (bool, optional): Computation of vectors can take a long + time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + vectors have been computed will be displayed. It will try to + use tqdm if available for advanced features (e.g. time estimation). + Otherwise, it will fallback to a simple output of progress. + Default: False + return_on_cpu (bool, optional): Whether to return the vectors on the cpu. + If None or False, is set to the device that the model is on. + Default: None + test (bool, optional): Whether to compute the vectors using the loss + function `test_loss_fn` provided in initialization (instead of + `loss_fn`). This argument does not matter if `test_loss_fn` was + not provided, as in this case, `test_loss_fn` and `loss_fn` are the + same. + + Returns: + intermediate_quantities (Tensor): This is a 2D tensor with shape + `(N, projection_dim)`, where `N` is the total number of examples in + `inputs_dataset`, and `projection_dim` was provided in + initialization. Each row contains the vector for a different + example. + """ + # if `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + if show_progress: + inputs_dataset = _progress_bar_constructor( + self, inputs_dataset, "inputs_dataset", "intermediate quantities" + ) + # infer model / data device through model + if return_on_cpu is None or (not return_on_cpu): + return_device = self.model_device + else: + return_device = torch.device("cpu") + + # as described in the description for `NaiveInfluenceFunction`, the embedding + # for an example `x` is :math`\nabla_\theta L(x)' R`. + # `_basic_computation_naive_influence_function` returns a 2D tensor where + # each row is :math`\nabla_\theta L(x)'` for a different example `x` in a + # batch. 
therefore, we right-multiply its output with `R` to get the embeddings + # for a batch, and then concatenate the per-batch embeddings to get embeddings + # for the entire dataset. + + # choose the correct loss function and reduction type based on `test` + loss_fn = self.test_loss_fn if test else self.loss_fn + reduction_type = self.test_reduction_type if test else self.reduction_type + + # define a helper function that returns the embeddings for a batch + def get_batch_embeddings(batch): + # if `self.R` is on cpu, and `self.model_device` was not cpu, this implies + # `self.R` was too large to fit in gpu memory, and we should do the matrix + # multiplication of the batch jacobians with `self.R` separately for each + # column of `self.R`, to avoid moving the entire `self.R` to gpu all at + # once and running out of gpu memory + batch_jacobians = _basic_computation_naive_influence_function( + self, batch[0:-1], batch[-1], loss_fn, reduction_type + ) + if self.R.device == torch.device( + "cpu" + ) and self.model_device != torch.device("cpu"): + return torch.stack( + [ + torch.matmul(batch_jacobians, R_col.to(batch_jacobians.device)) + for R_col in self.R.T + ], + dim=1, + ).to(return_device) + else: + return torch.matmul(batch_jacobians, self.R).to(device=return_device) + + # using `get_batch_embeddings` and a helper, return all the vectors or their + # sum, depending on `aggregate` + return _get_dataset_embeddings_intermediate_quantities_influence_function( + get_batch_embeddings, + inputs_dataset, + aggregate, + ) + + @log_usage(skip_self_logging=True) + def influence( # type: ignore[override] + self, + inputs: Tuple, + k: Optional[int] = None, + proponents: bool = True, + show_progress: bool = False, + ) -> Union[Tensor, KMostInfluentialResults]: + """ + This is the key method of this class, and can be run in 2 different modes, + where the mode that is run depends on the arguments passed to this method: + + - influence score mode: This mode is used if `k` is None. This mode computes + the influence score of every example in training dataset `train_dataset` + on every example in the test batch represented by `inputs`. + - k-most influential mode: This mode is used if `k` is not None, and an int. + This mode computes the proponents or opponents of every example in the + test batch represented by `inputs`. In particular, for each test example in + the test batch, this mode computes its proponents (resp. opponents), + which are the indices in the training dataset `train_dataset` of the + training examples with the `k` highest (resp. lowest) influence scores on the + test example. Proponents are computed if `proponents` is True. Otherwise, + opponents are computed. For each test example, this method also returns the + actual influence score of each proponent (resp. opponent) on the test + example. + + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): If not provided or `None`, the influence score mode will + be run. Otherwise, the k-most influential mode will be run, + and `k` is the number of proponents / opponents to return per + example in the test batch. 
+ Default: None + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`), if running in k-most influential + mode. + Default: True + show_progress (bool, optional): For all modes, computation of results + requires "training dataset computations": computations for each + batch in the training dataset `train_dataset`, which may + take a long time. If `show_progress` is true, the progress of + "training dataset computations" will be displayed. In particular, + the number of batches for which computations have been performed + will be displayed. It will try to use tqdm if available for + advanced features (e.g. time estimation). Otherwise, it will + fallback to a simple output of progress. + Default: False + + Returns: + The return value of this method depends on which mode is run. + + - influence score mode: if this mode is run (`k` is None), returns a 2D + tensor `influence_scores` of shape `(input_size, train_dataset_size)`, + where `input_size` is the number of examples in the test dataset, and + `train_dataset_size` is the number of examples in training dataset + `train_dataset`. In other words, `influence_scores[i][j]` is the + influence score of the `j`-th example in `train_dataset` on the `i`-th + example in the test batch. + - k-most influential mode: if this mode is run (`k` is an int), returns + a namedtuple `(indices, influence_scores)`. `indices` is a 2D tensor of + shape `(input_size, k)`, where `input_size` is the number of examples in + the test batch. If computing proponents (resp. opponents), + `indices[i][j]` is the index in training dataset `train_dataset` of the + example with the `j`-th highest (resp. lowest) influence score (out of + the examples in `train_dataset`) on the `i`-th example in the test + batch. `influence_scores` contains the corresponding influence scores. + In particular, `influence_scores[i][j]` is the influence score of example + `indices[i][j]` in `train_dataset` on example `i` in the test batch + represented by `inputs`. + """ + + return _influence_route_to_helpers( + self, + inputs, + k, + proponents, + show_progress=show_progress, + ) + + def _get_k_most_influential( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + k: int = 5, + proponents: bool = True, + show_progress: bool = False, + ) -> KMostInfluentialResults: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + k (int, optional): The number of proponents or opponents to return per test + example. + Default: 5 + proponents (bool, optional): Whether seeking proponents (`proponents=True`) + or opponents (`proponents=False`) + Default: True + show_progress (bool, optional): To compute the proponents (or opponents) + for the batch of examples, we perform computation for each batch in + training dataset `train_dataset`, If `show_progress` is + true, the progress of this computation will be displayed. In + particular, the number of batches for which the computation has + been performed will be displayed. It will try to use tqdm if + available for advanced features (e.g. time estimation). 
Otherwise, + it will fallback to a simple output of progress. + Default: False + + Returns: + (indices, influence_scores) (namedtuple): `indices` is a torch.long Tensor + that contains the indices of the proponents (or opponents) for each + test example. Its dimension is `(inputs_batch_size, k)`, where + `inputs_batch_size` is the number of examples in `inputs`. For + example, if `proponents==True`, `indices[i][j]` is the index of the + example in training dataset `train_dataset` with the + k-th highest influence score for the j-th example in `inputs`. + `indices` is a `torch.long` tensor so that it can directly be used + to index other tensors. Each row of `influence_scores` contains the + influence scores for a different test example, in sorted order. In + particular, `influence_scores[i][j]` is the influence score of + example `indices[i][j]` in training dataset `train_dataset` + on example `i` in the test dataset represented by `inputs`. + """ + desc = ( + None + if not show_progress + else ( + ( + f"Using {self.get_name()} to perform computation for " + f'getting {"proponents" if proponents else "opponents"}. ' + "Processing training batches" + ) + ) + ) + return KMostInfluentialResults( + *_get_k_most_influential_helper( + self.train_dataloader, + functools.partial( + _influence_batch_intermediate_quantities_influence_function, self + ), + inputs, + k, + proponents, + show_progress, + desc, + ) + ) + + def _influence( + self, + inputs: Union[Tuple[Any, ...], DataLoader], + show_progress: bool = False, + ) -> Tensor: + r""" + Args: + + inputs (tuple): `inputs` is the test batch and is a tuple of + any, where the last element is assumed to be the labels for the + batch. That is, `model(*batch[0:-1])` produces the output for + `model`, and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset` - please see its documentation in `__init__` for + more details on the assumed structure of a batch. + show_progress (bool, optional): To compute the influence of examples in + training dataset `train_dataset`, we compute the influence + of each batch. If `show_progress` is true, the progress of this + computation will be displayed. In particular, the number of batches + for which influence has been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + + Returns: + influence_scores (Tensor): Influence scores over the entire + training dataset `train_dataset`. Dimensionality is + (inputs_batch_size, src_dataset_size). For example: + influence_scores[i][j] = the influence score for the j-th training + example to the i-th example in the test dataset. + """ + # turn inputs and targets into a dataset. inputs has already been processed + # so that it should always be unpacked + inputs_dataset = _format_inputs_dataset(inputs) + return _influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) + + @log_usage(skip_self_logging=True) + def self_influence( + self, + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]] = None, + show_progress: bool = False, + ) -> Tensor: + """ + Computes self influence scores for the examples in `inputs_dataset`, which is + either a single batch or a Pytorch `DataLoader` that yields batches. 
Therefore, + the computed self influence scores are *not* for the examples in training + dataset `train_dataset` (unlike when computing self influence scores using the + `influence` method). Note that if `inputs_dataset` is a single batch, this + will call `model` on that single batch, and if `inputs_dataset` yields + batches, this will call `model` on each batch that is yielded. Therefore, + please ensure that for both cases, the batch(es) that `model` is called + with are not too large, so that there will not be an out-of-memory error. + + Implementation-wise, the self-influence score for an example is simply the + squared norm of the example's "embedding" vector. Therefore, the implementation + leverages `compute_intermediate_quantities`. + + Args: + inputs_dataset (tuple or DataLoader): Either a single tuple of any, or a + `DataLoader`, where each batch yielded is a tuple of any. In + either case, the tuple represents a single batch, where the last + element is assumed to be the labels for the batch. That is, + `model(*batch[0:-1])` produces the output for `model`, + and `batch[-1]` are the labels, if any. This is the same + assumption made for each batch yielded by training dataset + `train_dataset`. + Default: None + show_progress (bool, optional): Computation of self influence scores can + take a long time if `inputs_dataset` represents many examples. If + `show_progress`is true, the progress of this computation will be + displayed. In particular, the number of batches for which + self influence scores have been computed will be displayed. It will + try to use tqdm if available for advanced features (e.g. time + estimation). Otherwise, it will fallback to a simple output of + progress. + Default: False + """ + return _self_influence_helper_intermediate_quantities_influence_function( + self, inputs_dataset, show_progress + ) + + +def _basic_computation_naive_influence_function( + influence_inst: InfluenceFunctionBase, + inputs: Tuple[Any, ...], + targets: Optional[Tensor] = None, + loss_fn: Optional[Union[Module, Callable]] = None, + reduction_type: Optional[str] = None, +) -> Tensor: + """ + This computes the per-example parameter gradients for a batch, flattened into a + 2D tensor where the first dimension is batch dimension. This is used by + `NaiveInfluenceFunction` which computes embedding vectors for each example by + projecting their parameter gradients. + """ + # `jacobians` contains one tensor for each parameter we compute jacobians for. 
+ # the first dimension of each tensor is the batch dimension, and the remaining + # dimensions correspond to the parameter, so that for the tensor corresponding + # to parameter `p`, its shape is `(batch_size, *p.shape)` + jacobians = _compute_jacobian_sample_wise_grads_per_batch( + influence_inst, inputs, targets, loss_fn, reduction_type + ) + + return torch.stack( + [ + _flatten_params(tuple(jacobian[i] for jacobian in jacobians)) + for i in range(len(next(iter(jacobians)))) + ], + dim=0, + ) diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index c66e8ca7b4..5ef223d43c 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -4,34 +4,24 @@ import warnings from abc import abstractmethod from os.path import join -from typing import ( - Any, - Callable, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union import torch from captum._utils.av import AV -from captum._utils.common import _get_module_from_name, _parse_version -from captum._utils.gradient import ( - _compute_jacobian_wrt_params, - _compute_jacobian_wrt_params_with_sample_wise_trick, -) +from captum._utils.common import _parse_version from captum._utils.progress import NullProgress, progress from captum.influence._core.influence import DataInfluence from captum.influence._utils.common import ( _check_loss_fn, + _compute_jacobian_sample_wise_grads_per_batch, _format_inputs_dataset, _get_k_most_influential_helper, _gradient_dot_product, + _influence_route_to_helpers, _load_flexible_state_dict, _self_influence_by_batches_helper, + _set_active_parameters, + KMostInfluentialResults, ) from captum.log import log_usage from torch import Tensor @@ -69,24 +59,6 @@ """ -class KMostInfluentialResults(NamedTuple): - """ - This namedtuple stores the results of using the `influence` method. This method - is implemented by all subclasses of `TracInCPBase` to calculate - proponents / opponents. The `indices` field stores the indices of the - proponents / opponents for each example in the test dataset. For example, if - finding opponents, `indices[i][j]` stores the index in the training data of the - example with the `j`-th highest influence score on the `i`-th example in the test - dataset. Similarly, the `influence_scores` field stores the actual influence - scores, so that `influence_scores[i][j]` is the influence score of example - `indices[i][j]` in the training data on example `i` of the test dataset. - Please see `TracInCPBase.influence` for more details. - """ - - indices: Tensor - influence_scores: Tensor - - class TracInCPBase(DataInfluence): """ To implement the `influence` method, classes inheriting from `TracInCPBase` will @@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str: return cls.__name__ -def _influence_route_to_helpers( - influence_instance: TracInCPBase, - inputs: Union[Tuple[Any, ...], DataLoader], - k: Optional[int] = None, - proponents: bool = True, - **kwargs, -) -> Union[Tensor, KMostInfluentialResults]: - """ - This is a helper function called by `TracInCP.influence` and - `TracInCPFast.influence`. Those methods share a common logic in that they assume - an instance of their respective classes implement 2 private methods - (``_influence`, `_get_k_most_influential`), and the logic of - which private method to call is common, as described in the documentation of the - `influence` method. 
The arguments and return values of this function are the exact - same as the `influence` method. Note that `influence_instance` refers to the - instance for which the `influence` method was called. - """ - if k is None: - return influence_instance._influence(inputs, **kwargs) - else: - return influence_instance._get_k_most_influential( - inputs, - k, - proponents, - **kwargs, - ) - - class TracInCP(TracInCPBase): def __init__( self, @@ -630,23 +574,7 @@ def __init__( """ self.layer_modules = None if layers is not None: - assert isinstance(layers, List), "`layers` should be a list!" - assert len(layers) > 0, "`layers` cannot be empty!" - assert isinstance( - layers[0], str - ), "`layers` should contain str layer names." - self.layer_modules = [ - _get_module_from_name(self.model, layer) for layer in layers - ] - for layer, layer_module in zip(layers, self.layer_modules): - for name, param in layer_module.named_parameters(): - if not param.requires_grad: - warnings.warn( - "Setting required grads for layer: {}, name: {}".format( - ".".join(layer), name - ) - ) - param.requires_grad = True + self.layer_modules = _set_active_parameters(model, layers) @log_usage() def influence( # type: ignore[override] @@ -1463,19 +1391,6 @@ def _basic_computation_tracincp( argument is only used if `sample_wise_grads_per_batch` was true in initialization. """ - if self.sample_wise_grads_per_batch: - return _compute_jacobian_wrt_params_with_sample_wise_trick( - self.model, - inputs, - targets, - loss_fn, - reduction_type, - self.layer_modules, - ) - return _compute_jacobian_wrt_params( - self.model, - inputs, - targets, - loss_fn, - self.layer_modules, + return _compute_jacobian_sample_wise_grads_per_batch( + self, inputs, targets, loss_fn, reduction_type ) diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index 4acfabcd42..d2bde2e8da 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -1,19 +1,38 @@ #!/usr/bin/env python3 import warnings -from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union +from functools import reduce +from typing import ( + Any, + Callable, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) import torch import torch.nn as nn -from captum._utils.common import _parse_version +from captum._utils.common import _get_module_from_name, _parse_version +from captum._utils.gradient import ( + _compute_jacobian_wrt_params, + _compute_jacobian_wrt_params_with_sample_wise_trick, +) from captum._utils.progress import progress -if TYPE_CHECKING: - from captum.influence._core.tracincp import TracInCPBase - from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset +if TYPE_CHECKING: + from captum.influence._core.influence_function import ( + InfluenceFunctionBase, + IntermediateQuantitiesInfluenceFunction, + ) + from captum.influence._core.tracincp import TracInCP, TracInCPBase + def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor: r""" @@ -422,7 +441,7 @@ def _self_influence_by_batches_helper( def _check_loss_fn( - influence_instance: "TracInCPBase", + influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"], loss_fn: Optional[Union[Module, Callable]], loss_fn_name: str, sample_wise_grads_per_batch: Optional[bool] = None, @@ -505,3 +524,556 @@ def _check_loss_fn( ) return reduction_type + + +def _set_active_parameters(model: Module, layers: List[str]) -> List[Module]: + """ + sets relevant parameters, as indicated by 
`layers`, to have `requires_grad=True`, + and returns relevant modules. + """ + assert isinstance(layers, List), "`layers` should be a list!" + assert len(layers) > 0, "`layers` cannot be empty!" + assert isinstance(layers[0], str), "`layers` should contain str layer names." + layer_modules = [_get_module_from_name(model, layer) for layer in layers] + for layer, layer_module in zip(layers, layer_modules): + for name, param in layer_module.named_parameters(): + if not param.requires_grad: + warnings.warn( + "Setting required grads for layer: {}, name: {}".format( + ".".join(layer), name + ) + ) + param.requires_grad = True + return layer_modules + + +def _progress_bar_constructor( + influence_inst: "InfluenceFunctionBase", + inputs_dataset: DataLoader, + quantities_name: str, + dataset_name: str = "inputs_dataset", +): + # Try to determine length of progress bar if possible, with a default + # of `None`. + inputs_dataset_len = None + try: + inputs_dataset_len = len(inputs_dataset) + except TypeError: + warnings.warn( + f"Unable to determine the number of batches in " + f"`{dataset_name}`. Therefore, if showing the progress " + f"of the computation of {quantities_name}, " + "only the number of batches processed can be " + "displayed, and not the percentage completion of the computation, " + "nor any time estimates." + ) + + return progress( + inputs_dataset, + desc=( + f"Using {influence_inst.get_name()} to compute {quantities_name}. " + "Processing batch" + ), + total=inputs_dataset_len, + ) + + +def _params_to_names(params: Iterable[nn.Parameter], model: nn.Module) -> List[str]: + """ + Given an iterable of parameters, `params` of a model, `model`, returns the names of + the parameters from the perspective of `model`. This is useful if, given + parameters for which we do not know the name, want to pass them as a dict + to a function of those parameters, i.e. `torch.nn.utils._stateless`. + """ + param_id_to_name = { + id(param): param_name for (param_name, param) in model.named_parameters() + } + return [param_id_to_name[id(param)] for param in params] + + +def _flatten_params(_params: Tuple[Tensor, ...]) -> Tensor: + """ + Given a tuple of tensors, which is how Pytorch represents parameters of a model, + flattens it into a single tensor. This is useful if we want to do matrix operations + on the parameters of a model, i.e. invert its Hessian, or compute dot-product of + parameter-gradients. Note that flattening and then passing to standard linear + algebra operations may not be the most efficient way to perform them. 
+ """ + return torch.cat([_param.view(-1) for _param in _params]) + + +def _unflatten_params_factory( + param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]] +): + """ + returns a function which is the inverse of `_flatten_params` + """ + + def _unflatten_params(flattened_params): + params = [] + offset = 0 + for shape in param_shapes: + length = 1 + for s in shape: + length *= s + params.append(flattened_params[offset : offset + length].view(shape)) + offset += length + return tuple(params) + + return _unflatten_params + + +def _influence_batch_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + test_batch: Tuple[Any, ...], + train_batch: Tuple[Any, ...], +): + """ + computes influence of a test batch on a train batch, for implementations of + `IntermediateQuantitiesInfluenceFunction` + """ + return torch.matmul( + influence_inst.compute_intermediate_quantities(test_batch), + influence_inst.compute_intermediate_quantities(train_batch).T, + ) + + +def _influence_helper_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + inputs_dataset: Union[Tuple[Any, ...], DataLoader], + show_progress: bool, +): + """ + Helper function that computes influence scores for implementations of + `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` + method returning "embedding" vectors, so that the influence score of one example + on another is the dot-product of their vectors. + """ + # If `inputs_dataset` is not a `DataLoader`, turn it into one. + inputs_dataset = _format_inputs_dataset(inputs_dataset) + + inputs_intermediate_quantities = influence_inst.compute_intermediate_quantities( + inputs_dataset, + show_progress=show_progress, + test=True, + ) + + train_dataloader = influence_inst.train_dataloader + if show_progress: + train_dataloader = _progress_bar_constructor( + influence_inst, train_dataloader, "train_dataset", "influence scores" + ) + + return torch.cat( + [ + torch.matmul( + inputs_intermediate_quantities, + influence_inst.compute_intermediate_quantities(batch).T, + ) + for batch in train_dataloader + ], + dim=1, + ) + + +def _self_influence_helper_intermediate_quantities_influence_function( + influence_inst: "IntermediateQuantitiesInfluenceFunction", + inputs_dataset: Optional[Union[Tuple[Any, ...], DataLoader]], + show_progress: bool, +): + """ + Helper function that computes self-influence scores for implementations of + `NaiveInfluenceFunction` which implement the `compute_intermediate_quantities` + method returning "embedding" vectors, so that the self-influence score of an + example is the squared norm of its vector. + """ + + inputs_dataset = ( + inputs_dataset + if inputs_dataset is not None + else influence_inst.train_dataloader + ) + + # If `inputs_dataset` is not a `DataLoader`, turn it into one. 
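    # Editorial sketch (hypothetical `emb_test` / `emb_train` names): for any
    # implementation whose `compute_intermediate_quantities` returns one
    # "embedding" row per example, the two helpers above and this one reduce to
    # plain matrix algebra. With `emb_test` of shape (n_test, d) and `emb_train`
    # of shape (n_train, d):
    #
    #     pairwise_influence = emb_test @ emb_train.T        # (n_test, n_train)
    #     self_influence = (emb_train ** 2).sum(dim=1)       # (n_train,)
    #     # equivalently, torch.diag(emb_train @ emb_train.T)
    #
    # i.e. self influence scores are the diagonal of the pairwise influence
    # matrix, computed here batch-by-batch without forming the full matrix.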
+ inputs_dataset = _format_inputs_dataset(inputs_dataset) + + if show_progress: + inputs_dataset = _progress_bar_constructor( + influence_inst, inputs_dataset, "inputs_dataset", "self influence scores" + ) + + return torch.cat( + [ + torch.sum( + influence_inst.compute_intermediate_quantities( + batch, + show_progress=False, + ) + ** 2, + dim=1, + ) + for batch in inputs_dataset + ] + ) + + +def _eig_helper(H: Tensor): + """ + wrapper around `torch.linalg.eig` that sorts eigenvalues / eigenvectors by + ascending eigenvalues, like `torch.linalg.eigh`, and returns the real component + (since `H` is never complex, there should never be a complex component. however, + `torch.linalg.eig` always returns a complex tensor, which in this case would + actually have no complex component) + """ + version = _parse_version(torch.__version__) + if version < (1, 9): + ls, vs = torch.eig(H, eigenvectors=True) + ls = ls[:, 0] + else: + ls, vs = torch.linalg.eig(H) + ls, vs = ls.real, vs.real + + ls_argsort = torch.argsort(ls) + vs = vs[:, ls_argsort] + ls = ls[ls_argsort] + return ls, vs + + +def _top_eigen( + H: Tensor, k: Optional[int], hessian_reg: float, hessian_inverse_tol: float +) -> Tuple[Tensor, Tensor]: + """ + This is a wrapper around `torch.linalg.eig` that performs some pre / + post-processing to make it suitable for computing the low-rank + "square root" of a matrix, i.e. given square matrix H, find tall and + skinny L such that LL' approximates H. This function returns eigenvectors (as the + columns of a matrix Q) and corresponding eigenvectors (as diagonal entries in + a matrix V), and we can then let L=QV^{1/2}Q'. However, doing so requires the + eigenvalues in V to be positive. Thus, this function does pre-processing (adds + an entry to the diagonal of H) and post-processing (returns only the top-k + eigenvectors / eigenvalues where the eigenvalues are above a positive tolerance) + to encourage and guarantee, respectively, that the returned eigenvalues be + positive. The pre-processing shifts the eigenvalues up by a constant, and the + post-processing effectively replaces H with the most similar matrix (in terms of + Frobenius norm) whose eigenvalues are above the tolerance, see + https://nhigham.com/2021/01/26/what-is-the-nearest-positive-semidefinite-matrix/. + + Args: + H (Tensor): a 2D square Tensor for which the top eigenvectors / eigenvalues + will be computed. + k (int): how many eigenvectors / eigenvalues to return (before dropping pairs + whose eigenvalue is below the tolerance). + hessian_reg (float): We add an entry to the diagonal of `H` to encourage it to + be positive definite. This is that entry. + hessian_inverse_tol (float): To compute the "square root" of `H` using the top + eigenvectors / eigenvalues, the eigenvalues should be positive, and + furthermore if above a tolerance, the inversion will be more + numerically stable. Therefore, we only return eigenvectors / + eigenvalues where the eigenvalue is above a tolerance. This argument + specifies that tolerance. + + Returns: + (eigenvalues, eigenvectors) (tuple of tensors): Mimicking the output of + `torch.linalg.eigh`, `eigenvalues` is a 1D tensor of the top-k + eigenvalues of the regularized `H` that are additionally above + `hessian_inverse_tol`, and `eigenvectors` is a 2D tensor whose columns + contain the corresponding eigenvectors. The eigenvalues are in + ascending order. 
+ """ + # add regularization to hopefully make H positive definite + H = H + (torch.eye(len(H)).to(device=H.device) * hessian_reg) + + # find eigvectors / eigvals of H + # ls are eigenvalues, in ascending order + # columns of vs are corresponding eigenvectors + ls, vs = _eig_helper(H) + + # despite adding regularization to the hessian, it may still not be positive + # definite. we can get rid of negative eigenvalues, but for numerical stability + # can get rid of eigenvalues below a tolerance + keep = ls > hessian_inverse_tol + + ls = ls[keep] + vs = vs[:, keep] + + # only keep the top `k` eigvals / eigvectors + if not (k is None): + ls = ls[-k:] + vs = vs[:, -k:] + + # `torch.linalg.eig` is not deterministic in that you can multiply an eigenvector + # by -1, and it is still an eigenvector. to make eigenvectors deterministic, + # we multiply an eigenvector according to some rule that flips if you multiply + # the eigenvector by -1. in this case, that rule is whether the sum of the + # entries of the eigenvector are > 0 + rule = torch.sum(vs, dim=0) > 0 # entries are 0/1 + rule_multiplier = (2 * rule) - 1 # entries are -1/1 + vs = vs * rule_multiplier.unsqueeze(0) + + return ls, vs + + +class KMostInfluentialResults(NamedTuple): + """ + This namedtuple stores the results of using the `influence` method. This method + is implemented by all subclasses of `TracInCPBase` to calculate + proponents / opponents. The `indices` field stores the indices of the + proponents / opponents for each example in the test batch. For example, if finding + opponents, `indices[i][j]` stores the index in the training data of the example + with the `j`-th highest influence score on the `i`-th example in the test batch. + Similarly, the `influence_scores` field stores the actual influence scores, so that + `influence_scores[i][j]` is the influence score of example `indices[i][j]` in the + training data on example `i` of the test batch. Please see `TracInCPBase.influence` + for more details. + """ + + indices: Tensor + influence_scores: Tensor + + +def _influence_route_to_helpers( + influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"], + inputs: Union[Tuple[Any, ...], DataLoader], + k: Optional[int] = None, + proponents: bool = True, + **kwargs, +) -> Union[Tensor, KMostInfluentialResults]: + """ + This is a helper function called by `TracInCPBase` and `InfluenceFunctionBase` + implementations. Those methods share a common logic in that they assume + an instance of their respective classes implement 2 private methods + (``_influence`, `_get_k_most_influential`), and the logic of + which private method to call is common, as described in the documentation of the + `influence` method. The arguments and return values of this function are the exact + same as the `influence` method. Note that `influence_instance` refers to the + instance for which the `influence` method was called. + """ + if k is None: + return influence_instance._influence(inputs, **kwargs) + else: + return influence_instance._get_k_most_influential( + inputs, + k, + proponents, + **kwargs, + ) + + +def _parameter_dot( + params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...] +) -> Tensor: + """ + returns the dot-product of 2 tensors, represented as tuple of tensors. + """ + return torch.tensor( + sum( + torch.sum(param_1 * param_2) + for (param_1, param_2) in zip(params_1, params_2) + ) + ) + + +def _parameter_add( + params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...] 
+) -> Tuple[Tensor, ...]:
+    """
+    returns the sum of 2 tensors, represented as tuple of tensors.
+    """
+    return tuple(param_1 + param_2 for (param_1, param_2) in zip(params_1, params_2))
+
+
+def _parameter_multiply(params: Tuple[Tensor, ...], c: Tensor) -> Tuple[Tensor, ...]:
+    """
+    multiplies all tensors in a tuple of tensors by a given scalar
+    """
+    return tuple(param * c for param in params)
+
+
+def _parameter_to(params: Tuple[Tensor, ...], **to_kwargs) -> Tuple[Tensor, ...]:
+    """
+    applies the `to` method to all tensors in a tuple of tensors
+    """
+    return tuple(param.to(**to_kwargs) for param in params)
+
+
+def _parameter_linear_combination(
+    paramss: List[Tuple[Tensor, ...]], cs: Tensor
+) -> Tuple[Tensor, ...]:
+    """
+    scales each parameter (tuple of tensors) in a list by the corresponding scalar in a
+    1D tensor of the same length, and sums up the scaled parameters
+    """
+    assert len(cs.shape) == 1
+    result = _parameter_multiply(paramss[0], cs[0])
+    for (params, c) in zip(paramss[1:], cs[1:]):
+        result = _parameter_add(result, _parameter_multiply(params, c))
+    return result
+
+
+def _compute_jacobian_sample_wise_grads_per_batch(
+    influence_inst: Union["TracInCP", "InfluenceFunctionBase"],
+    inputs: Tuple[Any, ...],
+    targets: Optional[Tensor] = None,
+    loss_fn: Optional[Union[Module, Callable]] = None,
+    reduction_type: Optional[str] = "none",
+) -> Tuple[Tensor, ...]:
+    """
+    `TracInCP`, `NaiveInfluenceFunction`, and `ArnoldiInfluenceFunction` all compute
+    jacobians, depending on their `sample_wise_grads_per_batch` attribute. this helper
+    wraps that logic.
+    """
+
+    if influence_inst.sample_wise_grads_per_batch:
+        return _compute_jacobian_wrt_params_with_sample_wise_trick(
+            influence_inst.model,
+            inputs,
+            targets,
+            loss_fn,
+            reduction_type,
+            influence_inst.layer_modules,
+        )
+    return _compute_jacobian_wrt_params(
+        influence_inst.model,
+        inputs,
+        targets,
+        loss_fn,
+        influence_inst.layer_modules,
+    )
+
+
+def _compute_batch_loss_influence_function_base(
+    loss_fn: Optional[Union[Module, Callable]],
+    input: Any,
+    target: Any,
+    reduction_type: str,
+):
+    """
+    In implementations of `InfluenceFunctionBase`, we need to compute the total loss
+    for a batch given `loss_fn`, whose reduction can either be 'none', 'sum', or
+    'mean', and whose output requires different scaling based on the reduction. This
+    helper houses that common logic, and returns the total loss for a batch given the
+    predictions (`input`) and labels (`target`) for it. We compute the total loss
+    in order to compute the Hessian.
+    """
+    if loss_fn is not None:
+        _loss = loss_fn(input, target)
+    else:
+        # following the convention of `_compute_jacobian_wrt_params`, if no loss
+        # function is provided, the quantity backpropped is the output of the
+        # forward function.
+        assert reduction_type == "none"
+        _loss = input
+
+    if reduction_type == "none":
+        # if loss_fn is a "reduction='none'" loss function, need to sum
+        # up the per-example losses.
+        return torch.sum(_loss)
+    elif reduction_type == "mean":
+        # in this case, we want the total loss for the batch, and should
+        # multiply the mean loss for the batch by the batch size. however,
+        # we can only infer the batch size if `input` is a Tensor, and
+        # we assume the 0-th dimension to be the batch dimension.
+        if isinstance(input, Tensor):
+            multiplier = input.shape[0]
+        else:
+            multiplier = 1
+            msg = (
+                "`loss_fn` was inferred to behave as a `reduction='mean'` "
+                "loss function. however, the batch size of batches could not "
+                "be inferred. 
therefore, the total loss of a batch, which is " + "needed to compute the Hessian, is approximated as the output " + "of `loss_fn` for the batch. if this approximation is not " + "accurate, please change `loss_fn` to behave as a " + "`reduction='sum'` loss function, or a `reduction='none'` " + "and set `sample_grads_per_batch` to false." + ) + warnings.warn(msg) + return _loss * multiplier + elif reduction_type == "sum": + return _loss + else: + # currently, only support `reduction_type` to be + # 'none', 'sum', or 'mean' for + # `InfluenceFunctionBase` implementations + raise Exception + + +def _set_attr(obj, names, val): + if len(names) == 1: + setattr(obj, names[0], val) + else: + _set_attr(getattr(obj, names[0]), names[1:], val) + + +def _del_attr(obj, names): + if len(names) == 1: + delattr(obj, names[0]) + else: + _del_attr(getattr(obj, names[0]), names[1:]) + + +def _model_make_functional(model, param_names, params): + params = tuple([param.detach().requires_grad_() for param in params]) + + for param_name in param_names: + _del_attr(model, param_name.split(".")) + + return params + + +def _model_reinsert_params(model, param_names, params, register=False): + for (param_name, param) in zip(param_names, params): + _set_attr( + model, + param_name.split("."), + torch.nn.Parameter(param) if register else param, + ) + + +def _custom_functional_call(model, d, features): + param_names, params = zip(*list(d.items())) + _params = _model_make_functional(model, param_names, params) + _model_reinsert_params(model, param_names, params) + out = model(*features) + _model_reinsert_params(model, param_names, _params, register=True) + return out + + +def _functional_call(model, d, features): + """ + Makes a call to `model.forward`, which is treated as a function of the parameters + in `d`, a dict from parameter name to parameter, instead of as a function of + `features`, the argument that is unpacked to `model.forward` (i.e. + `model.forward(*features)`). Depending on what version of PyTorch is available, + we either use our own implementation, or directly use `torch.nn.utils.stateless` + or `torch.func.functional_call`. Put another way, this function mimics the latter + two implementations, using our own when the PyTorch version is too old. + """ + import torch + + version = _parse_version(torch.__version__) + if version < (1, 12, 0): + return _custom_functional_call(model, d, features) + elif version >= (1, 12, 0) and version < (2, 0, 0): + import torch.nn.utils.stateless + + return torch.nn.utils.stateless.functional_call(model, d, features) + else: + import torch.func + + return torch.func.functional_call(model, d, features) + + +def _dataset_fn(dataloader, batch_fn, reduce_fn, *batch_fn_args, **batch_fn_kwargs): + """ + Applies `batch_fn` to each batch in `dataloader`, reducing the results using + `reduce_fn`. This is useful for computing Hessians and Hessian-vector + products over an entire dataloader, and is used by both `NaiveInfluenceFunction` + and `ArnoldiInfluenceFunction`. 
+ """ + _dataloader = iter(dataloader) + + def _reduce_fn(_result, _batch): + return reduce_fn(_result, batch_fn(_batch, *batch_fn_args, **batch_fn_kwargs)) + + result = batch_fn(next(_dataloader), *batch_fn_args, **batch_fn_kwargs) + return reduce(_reduce_fn, _dataloader, result) diff --git a/tests/influence/_core/test_arnoldi_influence.py b/tests/influence/_core/test_arnoldi_influence.py new file mode 100644 index 0000000000..d0667b2ca6 --- /dev/null +++ b/tests/influence/_core/test_arnoldi_influence.py @@ -0,0 +1,509 @@ +import tempfile +from typing import Callable, List, Tuple, Union + +import torch + +import torch.nn as nn +from captum.influence._core.arnoldi_influence_function import ( + _parameter_arnoldi, + _parameter_distill, + ArnoldiInfluenceFunction, +) +from captum.influence._core.influence_function import NaiveInfluenceFunction +from captum.influence._utils.common import ( + _eig_helper, + _flatten_params, + _top_eigen, + _unflatten_params_factory, +) +from parameterized import parameterized +from tests.helpers.basic import assertTensorAlmostEqual, BaseTest +from tests.influence._utils.common import ( + _format_batch_into_tuple, + build_test_name_func, + DataInfluenceConstructor, + ExplicitDataset, + generate_assymetric_matrix_given_eigenvalues, + generate_symmetric_matrix_given_eigenvalues, + get_random_model_and_data, + UnpackDataset, + USE_GPU_LIST, +) +from torch import Tensor +from torch.utils.data import DataLoader + + +class TestArnoldiInfluence(BaseTest): + @parameterized.expand( + [ + (dim, rank) + for (dim, rank) in [ + (5, 2), + (10, 5), + (20, 15), + ] + ], + name_func=build_test_name_func(), + ) + def test_top_eigen(self, dim: int, rank: int): + # generate symmetric matrix of specific rank and check can recover it using + # the eigenvalues / eigenvectors returned by `_top_eigen` + R = torch.randn(dim, rank) + H = torch.matmul(R, R.T) + ls, vs = _top_eigen(H, rank, 1e-5, 1e-5) + assertTensorAlmostEqual(self, vs @ torch.diag(ls) @ vs.T, H, 1e-2, "max") + + @parameterized.expand( + [ + (symmetric, eigenvalues, k, arnoldi_dim, params_shapes) + for symmetric in [True, False] + for (eigenvalues, k, arnoldi_dim, params_shapes, test_name) in [ + ( + 10 ** torch.linspace(-2, 2, 100), + 10, + 50, + [(4, 10), (15, 3), (3, 5)], + "standard", + ), + ] + ], + name_func=build_test_name_func(args_to_skip=["eigenvalues", "params_shapes"]), + ) + def test_parameter_arnoldi( + self, + symmetric: bool, + eigenvalues: Tensor, + k: int, + arnoldi_dim: int, + params_shapes: List[Tuple], + ): + """ + This performs the tests of https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/arnoldi_test.py#L96 # noqa: E501 + See `_test_parameter_arnoldi_and_distill` documentation for 'arnoldi' + mode for details. 
+ """ + self._test_parameter_arnoldi_and_distill( + "arnoldi", symmetric, eigenvalues, k, arnoldi_dim, params_shapes + ) + + @parameterized.expand( + [ + (symmetric, eigenvalues, k, arnoldi_dim, params_shapes) + for symmetric in [True, False] + for (eigenvalues, k, arnoldi_dim, params_shapes, test_name) in [ + ( + 10 ** torch.linspace(-2, 2, 100), + 10, + 50, + [(4, 10), (15, 3), (3, 5)], + "standard", + ), + ] + ], + name_func=build_test_name_func(args_to_skip=["eigenvalues", "params_shapes"]), + ) + def test_parameter_distill( + self, + symmetric: bool, + eigenvalues: Tensor, + k: int, + arnoldi_dim: int, + params_shapes: List[Tuple], + ): + """ + This performs the tests of https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/arnoldi_test.py#L116 # noqa: E501 + See `_test_parameter_arnoldi_and_distill` documentation for + 'distill' mode for details. + """ + self._test_parameter_arnoldi_and_distill( + "distill", symmetric, eigenvalues, k, arnoldi_dim, params_shapes + ) + + def _test_parameter_arnoldi_and_distill( + self, + mode: str, + symmetric: bool, + eigenvalues: Tensor, + k: int, + arnoldi_dim: int, + param_shape: List[Tuple], + ): + """ + This is a helper with 2 modes. For both modes, it first generates a matrix + with `A` with specified eigenvalues. + + When mode is 'arnoldi', it checks that `_parameter_arnoldi` is correct. + In particular, it checks that the top-`k` eigenvalues of the restriction + of `A` to a Krylov subspace (the `H` returned by `_parameter_arnoldi`) + agree with those of the original matrix. This is a property we expect of the + Arnoldi iteration that `_parameter_arnoldi` implements. + + When mode is 'distill', it checks that `_parameter_distill` is correct. In + particular, it checks that the eigenvectors corresponding to the top + eigenvalues it returns agree with the top eigenvectors of `A`. This is the + property we require of `distill`, because we use the top eigenvectors (and + eigenvalues) of (implicitly-defined) `A` to calculate a low-rank approximation + of its inverse. + """ + # generate matrix `A` with specified eigenvalues + A = ( + generate_symmetric_matrix_given_eigenvalues(eigenvalues) + if symmetric + else generate_assymetric_matrix_given_eigenvalues(eigenvalues) + ) + + # create the matrix-vector multiplication function that `_parameter_arnoldi` + # expects that represents multiplication by `A`. + # since the vector actually needs to be a tuple of tensors, we + # specify the dimensions of that tuple of tensors. the function then + # flattens the vector, multiplies it by the generated matrix, and then + # unflattens the result + _unflatten_params = _unflatten_params_factory(param_shape) + + def _param_matmul(params: Tuple[Tensor]): + return _unflatten_params(torch.matmul(A, _flatten_params(params))) + + # generate `b` and call `_parameter_arnoldi` + b = tuple(torch.randn(shape) for shape in param_shape) + qs, H = _parameter_arnoldi( + _param_matmul, + b, + arnoldi_dim, + 1e-3, + torch.device("cpu"), + False, + ) + + assertTensorAlmostEqual( + self, + _flatten_params(_unflatten_params(_flatten_params(b))), + _flatten_params(b), + 1e-5, + "max", + ) + + # compute the eigenvalues / eigenvectors of `A` and `H`. we use `eig` since + # each matrix may not be symmetric. since `eig` does not sort by eigenvalues, + # need to manually do it. 
also get rid of last column of H, since + # it is not part of the decomposition + vs_A, ls_A = _eig_helper(A) + vs_H, ls_H = _eig_helper(H[:-1]) + + if mode == "arnoldi": + # compare the top-`k` eigenvalue of the two matrices + assertTensorAlmostEqual(self, vs_H[-k:], vs_A[-k:], 1e-3, "max") + elif mode == "distill": + # use `distill` to compute top-`k` eigenvectors of `H` in the original + # basis. then check if they are actually eigenvectors + vs_H_standard, ls_H_standard = _parameter_distill(qs, H, k, 0, 0) + + for (l_H_standard, v_A) in zip(ls_H_standard[-k:], vs_A[-k:]): + l_H_standard_flattened = _flatten_params(l_H_standard) # .real + expected = v_A * l_H_standard_flattened + actual = torch.matmul(A, l_H_standard_flattened) + # tol copied from original code + assert torch.norm(expected - actual) < 1e-2 + + # check that the top-`k` eigenvalues of `A` as computed by + # `_parameters_distill` are similar to those computed on `A` directly + for (v_H_standard, v_A) in zip(vs_H_standard[-k:], vs_A[-k:]): + # tol copied from original code + assert abs(v_H_standard - v_A) < 5 + + if False: + # code from original paper does not do this test, so skip for now + # use `distill`` to get top-`k` eigenvectors of `H` in the original + # basis, and compare with the top-`k` eigenvectors of `A`. need to + # flatten those from `distill` to compare + _, ls_H_standard = _parameter_distill(qs, H, k, 0, 0) + for (l_H_standard, l_A) in zip(ls_H_standard, ls_A): + # print(l_A) + # print(flatten_unflattener.flatten(l_H_standard).real) + l_H_standard_flattened /= torch.norm(l_H_standard_flattened) + assertTensorAlmostEqual( + self, + _flatten_params(l_H_standard).real, + l_A.real, + 1e-2, + "max", + ) + + @parameterized.expand( + [ + ( + influence_constructor_1, + influence_constructor_2, + delta, + mode, + unpack_inputs, + use_gpu, + ) + for use_gpu in USE_GPU_LIST + for (influence_constructor_1, influence_constructor_2, delta) in [ + # compare implementations, when considering only 1 layer + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=["module.linear1"] + if use_gpu == "cuda_dataparallel" + else ["linear1"], + projection_dim=5, + show_progress=False, + name="NaiveInfluenceFunction_linear1", + ), + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + layers=["module.linear1"] + if use_gpu == "cuda_dataparallel" + else ["linear1"], + arnoldi_dim=50, + arnoldi_tol=1e-5, # set low enough so that arnoldi subspace + # is large enough + projection_dim=5, + show_progress=False, + name="ArnoldiInfluenceFunction_linear1", + ), + 1e-2, + ), + # compare implementations, when considering all layers + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=None, + projection_dim=5, + show_progress=False, + name="NaiveInfluenceFunction_all_layers", + ), + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + layers=None, + arnoldi_dim=50, + arnoldi_tol=1e-5, # set low enough so that arnoldi subspace + # is large enough + projection_dim=5, + show_progress=False, + name="ArnoldiInfluenceFunction_all_layers", + ), + 1e-2, + ), + ] + for mode in [ + # we skip the 'intermediate_quantities' mode, as + # `NaiveInfluenceFunction` and `ArnoldiInfluenceFunction` return + # intermediate quantities lying in different coordinate systems, + # which cannot be expected to be the same. 
+ "self_influence", + "influence", + ] + for unpack_inputs in [ + False, + True, + ] + ], + name_func=build_test_name_func(), + ) + def test_compare_implementations_trained_NN_model_and_data( + self, + influence_constructor_1: Callable, + influence_constructor_2: Callable, + delta: float, + mode: str, + unpack_inputs: bool, + use_gpu: Union[bool, str], + ): + """ + this compares 2 influence implementations on a trained 2-layer NN model. + the implementations we compare are `NaiveInfluenceFunction` and + `ArnoldiInfluenceFunction`. because the model is trained, calculations + are more numerically stable, so that we can project to a higher dimension (5). + """ + self._test_compare_implementations( + "trained_NN", + influence_constructor_1, + influence_constructor_2, + delta, + mode, + unpack_inputs, + use_gpu, + ) + + # this compares `ArnoldiInfluenceFunction` and `NaiveInfluenceFunction` on randomly + # generated data. because these implementations are numerically equivalent, we + # can also compare the intermediate quantities. we do not compare with + # `NaiveInfluence` because on randomly generated data, it is not comparable, + # conceptually, with the other implementations, due to numerical issues. + @parameterized.expand( + [ + ( + influence_constructor_1, + influence_constructor_2, + delta, + mode, + unpack_inputs, + use_gpu, + ) + for use_gpu in USE_GPU_LIST + for (influence_constructor_1, influence_constructor_2, delta) in [ + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=["module.linear1"] + if use_gpu == "cuda_dataparallel" + else ["linear1"], + show_progress=False, + projection_dim=1, + ), + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + layers=["module.linear1"] + if use_gpu == "cuda_dataparallel" + else ["linear1"], + show_progress=False, + arnoldi_dim=50, + arnoldi_tol=1e-6, + projection_dim=1, + ), + 1e-2, + ), + ] + for mode in [ + # we skip the 'intermediate_quantities' mode, as + # `NaiveInfluenceFunction` and `ArnoldiInfluenceFunction` return + # intermediate quantities lying in different coordinate systems, + # which cannot be expected to be the same. + "self_influence", + "influence", + ] + for unpack_inputs in [ + False, + True, + ] + ], + name_func=build_test_name_func(), + ) + def test_compare_implementations_random_model_and_data( + self, + influence_constructor_1: Callable, + influence_constructor_2: Callable, + delta: float, + mode: str, + unpack_inputs: bool, + use_gpu: Union[bool, str], + ): + """ + this compares 2 influence implementations on a trained 2-layer NN model. + the implementations we compare are `NaiveInfluenceFunction` and + `ArnoldiInfluenceFunction`. because the model is not trained, calculations + are not numerically stable, and so we can only project to a low dimension (2). + """ + self._test_compare_implementations( + "random", + influence_constructor_1, + influence_constructor_2, + delta, + mode, + unpack_inputs, + use_gpu, + ) + + def _test_compare_implementations( + self, + model_type: str, + influence_constructor_1: Callable, + influence_constructor_2: Callable, + delta: float, + mode: str, + unpack_inputs: bool, + use_gpu: Union[bool, str], + ) -> None: + """ + checks that 2 implementations of `InfluenceFunctionBase` return the same + output, where the output is either self influence scores, or influence scores, + as determined by the `mode` input. this is a helper used by other tests. 
the + implementations are compared using the same data, but the model and saved + checkpoints can be different, and is specified using the `model_type` argument. + """ + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + hessian_samples, + hessian_labels, + test_samples, + test_labels, + ) = get_random_model_and_data( + tmpdir, + unpack_inputs, + return_test_data=True, + use_gpu=use_gpu, + return_hessian_data=True, + model_type=model_type, + ) + + train_dataset = DataLoader(train_dataset, batch_size=5) + + hessian_dataset = ( + ExplicitDataset(hessian_samples, hessian_labels, use_gpu) + if not unpack_inputs + else UnpackDataset(hessian_samples, hessian_labels, use_gpu) + ) + hessian_dataset = DataLoader(hessian_dataset, batch_size=5) + + criterion = nn.MSELoss(reduction="none") + batch_size = None + + influence_1 = influence_constructor_1( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + hessian_dataset=hessian_dataset, + ) + + influence_2 = influence_constructor_2( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + hessian_dataset=hessian_dataset, + ) + + if mode == "self_influence": + # compare self influence scores + assertTensorAlmostEqual( + self, + influence_1.self_influence(train_dataset), + influence_2.self_influence(train_dataset), + delta=delta, + mode="sum", + ) + elif mode == "intermediate_quantities": + # compare intermediate quantities + assertTensorAlmostEqual( + self, + influence_1.compute_intermediate_quantities(train_dataset), + influence_2.compute_intermediate_quantities(train_dataset), + delta=delta, + mode="max", + ) + elif mode == "influence": + # compare influence scores + assertTensorAlmostEqual( + self, + influence_1.influence( + _format_batch_into_tuple( + test_samples, test_labels, unpack_inputs + ) + ), + influence_2.influence( + _format_batch_into_tuple( + test_samples, test_labels, unpack_inputs + ) + ), + delta=delta, + mode="max", + ) + else: + raise Exception("unknown test mode") diff --git a/tests/influence/_core/test_naive_influence.py b/tests/influence/_core/test_naive_influence.py new file mode 100644 index 0000000000..e142fb0761 --- /dev/null +++ b/tests/influence/_core/test_naive_influence.py @@ -0,0 +1,281 @@ +import tempfile +from typing import Callable, List, Tuple, Union + +import torch + +import torch.nn as nn +from captum._utils.common import _parse_version +from captum.influence._core.influence_function import NaiveInfluenceFunction +from captum.influence._utils.common import ( + _custom_functional_call, + _flatten_params, + _functional_call, + _unflatten_params_factory, +) +from parameterized import parameterized +from tests.helpers.basic import ( + assertTensorAlmostEqual, + assertTensorTuplesAlmostEqual, + BaseTest, +) +from tests.influence._utils.common import ( + _format_batch_into_tuple, + build_test_name_func, + DataInfluenceConstructor, + ExplicitDataset, + get_random_model_and_data, + Linear, + UnpackDataset, + USE_GPU_LIST, +) +from torch.utils.data import DataLoader + + +class TestNaiveInfluence(BaseTest): + @parameterized.expand( + [ + (param_shape,) + for param_shape in [ + [(2, 3), (4, 5)], + [(3, 2), (4, 2), (1, 5)], + ] + ], + name_func=build_test_name_func(), + ) + def test_flatten_unflattener(self, param_shapes: List[Tuple[int, ...]]): + # unflatten and flatten should be inverses of each other. check this holds. 
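        # Concrete illustration (hypothetical shapes, for exposition only): with
        # param_shapes = [(2, 3), (4,)], `_flatten_params` concatenates the two
        # tensors into a single 1D tensor of length 2 * 3 + 4 = 10, and the
        # function returned by `_unflatten_params_factory([(2, 3), (4,)])` slices
        # that 1D tensor back into a (2, 3) tensor and a (4,) tensor:
        #
        #     p = (torch.ones(2, 3), torch.zeros(4))
        #     flat = _flatten_params(p)                              # shape (10,)
        #     q = _unflatten_params_factory([(2, 3), (4,)])(flat)
        #     # q[0].shape == (2, 3), q[1].shape == (4,), values unchanged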
+ _unflatten_params = _unflatten_params_factory(param_shapes) + params = tuple(torch.randn(shape) for shape in param_shapes) + assertTensorTuplesAlmostEqual( + self, + params, + _unflatten_params(_flatten_params(params)), + delta=1e-4, + mode="max", + ) + + @parameterized.expand( + [ + ( + reduction, + influence_constructor, + delta, + mode, + unpack_inputs, + use_gpu, + ) + for reduction in ["none", "sum", "mean"] + for use_gpu in USE_GPU_LIST + for (influence_constructor, delta) in [ + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=["module.linear"] + if use_gpu == "cuda_dataparallel" + else ["linear"], + projection_dim=None, + # letting projection_dim is None means no projection is done, + # in which case exact influence is returned + show_progress=False, + ), + 1e-3, + ), + ( + DataInfluenceConstructor( + NaiveInfluenceFunction, + layers=None, + # this tests that not specifyiing layers still works + projection_dim=None, + show_progress=False, + name="NaiveInfluenceFunction_all_layers", + ), + 1e-3, + ), + ] + for mode in [ + "influence", + "self_influence", + ] + for unpack_inputs in [ + False, + True, + ] + ], + name_func=build_test_name_func(), + ) + def test_matches_linear_regression( + self, + reduction: str, + influence_constructor: Callable, + delta: float, + mode: str, + unpack_inputs: bool, + use_gpu: Union[bool, str], + ): + """ + this tests that `NaiveInfluence`, the simplest implementation, agree with the + analytically calculated solution for influence and self-influence for a model + where we can calculate that solution - linear regression trained with squared + error loss. + """ + with tempfile.TemporaryDirectory() as tmpdir: + ( + net, + train_dataset, + hessian_samples, + hessian_labels, + test_samples, + test_labels, + ) = get_random_model_and_data( + tmpdir, + unpack_inputs, + return_test_data=True, + use_gpu=use_gpu, + return_hessian_data=True, + model_type="trained_linear", + ) + + train_dataset = DataLoader(train_dataset, batch_size=5) + + hessian_dataset = ( + ExplicitDataset(hessian_samples, hessian_labels, use_gpu) + if not unpack_inputs + else UnpackDataset(hessian_samples, hessian_labels, use_gpu) + ) + hessian_dataset = DataLoader(hessian_dataset, batch_size=5) + + criterion = nn.MSELoss(reduction=reduction) + batch_size = None + + # set `sample_grads_per_batch` based on `reduction` to be compatible + sample_wise_grads_per_batch = False if reduction == "none" else True + + influence = influence_constructor( + net, + train_dataset, + tmpdir, + batch_size, + criterion, + sample_wise_grads_per_batch=sample_wise_grads_per_batch, + hessian_dataset=hessian_dataset, + ) + + # since the model is a linear regression model trained with MSE loss, we + # can calculate the hessian and per-example parameter gradients + # analytically + tensor_hessian_samples = ( + hessian_samples + if not unpack_inputs + else torch.cat(hessian_samples, dim=1) + ) + # hessian at optimal parameters is 2 * X'X, where X is the feature matrix + # of the examples used for calculating the hessian. + # this is based on https://math.stackexchange.com/questions/2864585/hessian-on-linear-least-squares-problem # noqa: E501 + # and multiplying by 2, since we optimize squared error, + # not 1/2 squared error. 
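            # Worked derivation (editorial sketch, writing X for
            # `tensor_hessian_samples`, y for `hessian_labels`, theta for the
            # weights): the summed squared error is
            #     L(theta) = sum_i (x_i' theta - y_i)^2 = ||X theta - y||^2
            # so the gradient is 2 X'(X theta - y) and the Hessian is the
            # constant matrix 2 X'X. The per-example gradient used further below
            # is likewise 2 * (x_i' theta - y_i) * x_i, i.e. "2 * features * error".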
+ hessian = torch.matmul(tensor_hessian_samples.T, tensor_hessian_samples) * 2 + hessian = hessian + ( + torch.eye(len(hessian)).to(device=hessian.device) * 1e-4 + ) + version = _parse_version(torch.__version__) + if version < (1, 8): + hessian_inverse = torch.pinverse(hessian, rcond=1e-4) + else: + hessian_inverse = torch.linalg.pinv(hessian, rcond=1e-4) + + # gradient for an example is 2 * features * error + + # compute train gradients + tensor_train_samples = torch.cat( + [torch.cat(batch[:-1], dim=1) for batch in train_dataset], dim=0 + ) + train_predictions = torch.cat( + [net(*batch[:-1]) for batch in train_dataset], dim=0 + ) + train_labels = torch.cat([batch[-1] for batch in train_dataset], dim=0) + train_gradients = ( + (train_predictions - train_labels) * tensor_train_samples * 2 + ) + + # compute test gradients + tensor_test_samples = ( + test_samples if not unpack_inputs else torch.cat(test_samples, dim=1) + ) + test_predictions = ( + net(test_samples) if not unpack_inputs else net(*test_samples) + ) + test_gradients = (test_predictions - test_labels) * tensor_test_samples * 2 + + if mode == "influence": + # compute pairwise influences, analytically + analytical_train_test_influences = torch.matmul( + torch.matmul(test_gradients, hessian_inverse), train_gradients.T + ) + # compute pairwise influences using influence implementation + influence_train_test_influences = influence.influence( + _format_batch_into_tuple(test_samples, test_labels, unpack_inputs) + ) + # check error + assertTensorAlmostEqual( + self, + influence_train_test_influences, + analytical_train_test_influences, + delta=delta, + mode="max", + ) + elif mode == "self_influence": + # compute self influence, analytically + analytical_self_influences = torch.diag( + torch.matmul( + torch.matmul(train_gradients, hessian_inverse), + train_gradients.T, + ) + ) + # compute pairwise influences using influence implementation + influence_self_influences = influence.self_influence(train_dataset) + # check error + assertTensorAlmostEqual( + self, + influence_self_influences, + analytical_self_influences, + delta=delta, + mode="max", + ) + else: + raise Exception("unknown test mode") + + @parameterized.expand( + [(_custom_functional_call,), (_functional_call,)], + name_func=build_test_name_func(), + ) + def test_functional_call(self, method): + """ + tests `influence._utils.common._functional_call` for a simple case where the + model and loss are linear regression and squared error. `method` can either be + `_custom_functional_call`, which uses the custom implementation that is used + if pytorch does not provide one, or `_functional_call`, which uses a pytorch + implementation if available. 
+ """ + # get linear model and a batch + batch_size = 25 + num_features = 5 + batch_samples = torch.normal(0, 1, (batch_size, num_features)) + batch_labels = torch.normal(0, 1, (batch_size, 1)) + net = Linear(num_features) + + # get the analytical gradient wrt to model parameters + batch_predictions = net(batch_samples) + analytical_grad = 2 * torch.sum( + (batch_predictions - batch_labels) * batch_samples, dim=0 + ) + + # get gradient as computed using `_functional_call` + param = net.linear.weight.detach().clone().requires_grad_(True) + _batch_predictions = method(net, {"linear.weight": param}, (batch_samples,)) + loss = torch.sum((_batch_predictions - batch_labels) ** 2) + actual_grad = torch.autograd.grad(loss, param)[0][0] + + # they should be the same + assertTensorAlmostEqual( + self, actual_grad, analytical_grad, delta=1e-3, mode="max" + ) diff --git a/tests/influence/_core/test_tracin_intermediate_quantities.py b/tests/influence/_core/test_tracin_intermediate_quantities.py index b65cd0225c..0919190d96 100644 --- a/tests/influence/_core/test_tracin_intermediate_quantities.py +++ b/tests/influence/_core/test_tracin_intermediate_quantities.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +from captum.influence._core.arnoldi_influence_function import ArnoldiInfluenceFunction +from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, @@ -27,6 +29,8 @@ class TestTracInIntermediateQuantities(BaseTest): for unpack_inputs in [True, False] for (reduction, constructor) in [ ("none", DataInfluenceConstructor(TracInCP)), + ("none", DataInfluenceConstructor(NaiveInfluenceFunction)), + ("none", DataInfluenceConstructor(ArnoldiInfluenceFunction)), ] ], name_func=build_test_name_func(), @@ -83,6 +87,7 @@ def test_tracin_intermediate_quantities_aggregate( for (reduction, constructor) in [ ("sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("none", DataInfluenceConstructor(TracInCP)), + ("none", DataInfluenceConstructor(NaiveInfluenceFunction)), ] ], name_func=build_test_name_func(), @@ -166,6 +171,11 @@ def test_tracin_intermediate_quantities_api( DataInfluenceConstructor(TracInCP), DataInfluenceConstructor(TracInCP), ), + ( + "none", + DataInfluenceConstructor(NaiveInfluenceFunction), + DataInfluenceConstructor(NaiveInfluenceFunction), + ), ] ], name_func=build_test_name_func(), @@ -190,7 +200,9 @@ def test_tracin_intermediate_quantities_consistent( methods for the 2 cases are different, we need to parametrize the test with 2 different tracin constructors. `tracin_constructor` is the constructor for the tracin implementation for case 1. `intermediate_quantities_tracin_constructor` - is the constructor for the tracin implementation for case 2. + is the constructor for the tracin implementation for case 2. Note that we also + use this test for implementations of `InfluenceFunctionBase`, where for the + same method, both ways should give the same result by definition. 
""" with tempfile.TemporaryDirectory() as tmpdir: ( diff --git a/tests/influence/_core/test_tracin_regression.py b/tests/influence/_core/test_tracin_regression.py index 27b6ec9f5d..adcd2cd853 100644 --- a/tests/influence/_core/test_tracin_regression.py +++ b/tests/influence/_core/test_tracin_regression.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +from captum.influence._core.arnoldi_influence_function import ArnoldiInfluenceFunction +from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp import TracInCP from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, @@ -320,7 +322,7 @@ def _test_tracin_identity_regression_setup(self, tmpdir: str): num_checkpoints = 5 for i in range(num_checkpoints): - net.fc1.weight.data = torch.rand((1, num_features)) + net.fc1.weight.data = torch.rand((1, num_features)) * 100 checkpoint_name = "-".join(["checkpoint-reg", str(i) + ".pt"]) torch.save(net.state_dict(), os.path.join(tmpdir, checkpoint_name)) @@ -340,6 +342,15 @@ def _test_tracin_identity_regression_setup(self, tmpdir: str): ("check_idx", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("check_idx", "mean", DataInfluenceConstructor(TracInCPFast)), ("check_idx", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ("check_idx", "none", DataInfluenceConstructor(NaiveInfluenceFunction)), + ( + "check_idx", + "none", + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + arnoldi_tol=1e-8, # needs to be small to avoid empty arnoldi basis + ), + ), ], name_func=build_test_name_func(), ) @@ -435,6 +446,14 @@ def test_tracin_identity_regression( ("mean", "mean", DataInfluenceConstructor(TracInCPFast)), ("sum", "sum", DataInfluenceConstructor(TracInCPFastRandProj)), ("mean", "mean", DataInfluenceConstructor(TracInCPFastRandProj)), + ("none", "none", DataInfluenceConstructor(NaiveInfluenceFunction)), + # ( + # "none", + # "none", + # DataInfluenceConstructor(ArnoldiInfluenceFunction, arnoldi_tol=1e-9), + # # need to set `arnoldi_tol` small. otherwise, arnoldi iteration + # # terminates early and get 'Arnoldi basis is empty' exception. 
+ # ), ], name_func=build_test_name_func(), ) diff --git a/tests/influence/_core/test_tracin_self_influence.py b/tests/influence/_core/test_tracin_self_influence.py index 767aed6b02..8ab26c098f 100644 --- a/tests/influence/_core/test_tracin_self_influence.py +++ b/tests/influence/_core/test_tracin_self_influence.py @@ -3,7 +3,9 @@ import torch import torch.nn as nn -from captum.influence._core.tracincp import TracInCP +from captum.influence._core.arnoldi_influence_function import ArnoldiInfluenceFunction +from captum.influence._core.influence_function import NaiveInfluenceFunction +from captum.influence._core.tracincp import TracInCP, TracInCPBase from captum.influence._core.tracincp_fast_rand_proj import TracInCPFast from parameterized import parameterized from tests.helpers.basic import assertTensorAlmostEqual, BaseTest @@ -18,13 +20,18 @@ class TestTracInSelfInfluence(BaseTest): + param_list = [] + + # add the tests for `TracInCPBase` implementations and `InfluenceFunctionBase` + # implementations separately, because the latter does not support `DataParallel` + + # add tests for `TracInCPBase` implementations use_gpu_list = ( [False, "cuda", "cuda_data_parallel"] if torch.cuda.is_available() and torch.cuda.device_count() != 0 else [False] ) - param_list = [] for unpack_inputs in [True, False]: for use_gpu in use_gpu_list: for (reduction, constructor) in [ @@ -80,6 +87,57 @@ class TestTracInSelfInfluence(BaseTest): ): param_list.append((reduction, constructor, unpack_inputs, use_gpu)) + # add tests for `InfluenceFunctionBase` implementations + use_gpu_list = ( + [False, "cuda"] + if torch.cuda.is_available() and torch.cuda.device_count() != 0 + else [False] + ) + + for unpack_inputs in [True, False]: + for use_gpu in use_gpu_list: + for (reduction, constructor) in [ + ( + "none", + DataInfluenceConstructor( + NaiveInfluenceFunction, name="NaiveInfluenceFunction_all_layers" + ), + ), + ( + "none", + DataInfluenceConstructor( + NaiveInfluenceFunction, + name="NaiveInfluenceFunction_linear1", + layers=["module.linear1"] + if use_gpu == "cuda_data_parallel" + else ["linear1"], + ), + ), + ( + "none", + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + name="ArnoldiInfluenceFunction_all_layers", + ), + ), + ( + "none", + DataInfluenceConstructor( + ArnoldiInfluenceFunction, + name="ArnoldiInfluenceFunction_linear1", + layers=["module.linear1"] + if use_gpu == "cuda_data_parallel" + else ["linear1"], + ), + ), + ]: + if not ( + "sample_wise_grads_per_batch" in constructor.kwargs + and constructor.kwargs["sample_wise_grads_per_batch"] + and use_gpu + ): + param_list.append((reduction, constructor, unpack_inputs, use_gpu)) + @parameterized.expand( param_list, name_func=build_test_name_func(), @@ -117,9 +175,7 @@ def test_tracin_self_influence( k=None, ) # calculate self_tracin_scores - self_tracin_scores = tracin.self_influence( - outer_loop_by_checkpoints=False, - ) + self_tracin_scores = tracin.self_influence() # check that self_tracin scores equals the diagonal of influence scores assertTensorAlmostEqual( @@ -132,17 +188,22 @@ def test_tracin_self_influence( # check that setting `outer_loop_by_checkpoints=False` and # `outer_loop_by_checkpoints=True` gives the same self influence scores - self_tracin_scores_by_checkpoints = tracin.self_influence( - DataLoader(train_dataset, batch_size=batch_size), - outer_loop_by_checkpoints=True, - ) - assertTensorAlmostEqual( - self, - self_tracin_scores_by_checkpoints, - self_tracin_scores, - delta=0.01, - mode="max", - ) + # this test is only 
relevant for implementations of `TracInCPBase`, as + # implementations of `InfluenceFunctionBase` do not use checkpoints. + if isinstance(tracin, TracInCPBase): + self_tracin_scores_by_checkpoints = ( + tracin.self_influence( # type: ignore + DataLoader(train_dataset, batch_size=batch_size), + outer_loop_by_checkpoints=True, + ) + ) + assertTensorAlmostEqual( + self, + self_tracin_scores_by_checkpoints, + self_tracin_scores, + delta=0.01, + mode="max", + ) @parameterized.expand( [ diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py index 17fe5b46cb..e1aa886e16 100644 --- a/tests/influence/_utils/common.py +++ b/tests/influence/_utils/common.py @@ -2,12 +2,16 @@ import os import unittest from functools import partial -from typing import Callable, Iterator, List, Optional, Tuple, Union +from inspect import isfunction +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +from captum._utils.common import _parse_version from captum.influence import DataInfluence +from captum.influence._core.arnoldi_influence_function import ArnoldiInfluenceFunction +from captum.influence._core.influence_function import NaiveInfluenceFunction from captum.influence._core.tracincp_fast_rand_proj import ( TracInCPFast, TracInCPFastRandProj, @@ -180,10 +184,73 @@ def forward(self, *inputs): return torch.tanh(self.linear2(x)) +class Linear(nn.Module): + """ + a wrapper around `nn.Linear`, with purpose being to have an analogue to + `UnpackLinear`, with both's only parameter being 'linear'. "infinitesimal" + influence (i.e. that calculated by `InfluenceFunctionBase` implementations) for + this simple module can be analytically calculated, so its purpose is for testing + those implementations. + """ + + def __init__(self, in_features): + super().__init__() + self.linear = nn.Linear(in_features, 1, bias=False) + + def forward(self, input): + return self.linear(input) + + +class UnpackLinear(nn.Module): + """ + the analogue of `Linear` which unpacks inputs, serving the same purpose. + """ + + def __init__(self, in_features, num_inputs) -> None: + super().__init__() + self.linear = nn.Linear(in_features * num_inputs, 1, bias=False) + + def forward(self, *inputs): + return self.linear(torch.cat(inputs, dim=1)) + + def get_random_model_and_data( - tmpdir, unpack_inputs, return_test_data=True, use_gpu=False + tmpdir, + unpack_inputs, + return_test_data=True, + use_gpu=False, + return_hessian_data=False, + model_type="random", ): """ + returns a model, training data, and optionally data for computing the hessian + (needed for `InfluenceFunctionBase` implementations) as features / labels, and + optionally test data as features / labels. + + the data is always generated the same way. however depending on `model_type`, + a different model and checkpoints are returned. + - `model_type='random'`: the model is a 2-layer NN, and several checkpoints are + generated + - `model_type='trained_linear'`: the model is a linear model, and assumed to be + eventually trained to optimality. therefore, we find the optimal parameters, and + save a single checkpoint containing them. the training is done using the Hessian + data, because the purpose of training the model is so that the Hessian is positive + definite. since the Hessian is calculated using the Hessian data, it should be + used for training. 
since it is trained to optimality using the Hessian data, we can + guarantee that the Hessian is positive definite, so that different + implementations of `InfluenceFunctionBase` can be more easily compared. (if the + Hessian is not positive definite, we drop eigenvectors corresponding to negative + eigenvalues. since the eigenvectors dropped in `ArnoldiInfluence` differ from those + in `NaiveInfluenceFunction` due to the formers' use of Arnoldi iteration, we should + only use models / data whose Hessian is positive definite, so that no eigenvectors + are dropped). in short, this model / data are suitable for comparing different + `InfluenceFunctionBase` implementations. + - `model_type='trained_NN'`: the model is a 2-layer NN, and trained (not + necessarily) to optimality using the Hessian data. since it is trained, for same + reasons as for `model_type='trained_linear`, different implementations of + `InfluenceFunctionBase` can be more easily compared, due to lack of numerical + issues. + `use_gpu` can either be - `False`: returned model is on cpu - `'cuda'`: returned model is on gpu @@ -192,57 +259,54 @@ def get_random_model_and_data( is that sometimes we may want to test a model that is on cpu, but is *not* wrapped in `DataParallel`. """ - assert use_gpu in [False, "cuda", "cuda_data_parallel"] - - in_features, hidden_nodes, out_features = 5, 4, 3 + in_features, hidden_nodes = 5, 4 num_inputs = 2 - net = ( - BasicLinearNet(in_features, hidden_nodes, out_features) - if not unpack_inputs - else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs) - ).double() - - num_checkpoints = 5 - - for i in range(num_checkpoints): - net.linear1.weight.data = torch.normal( - 3, 4, (hidden_nodes, in_features) - ).double() - net.linear2.weight.data = torch.normal( - 5, 6, (out_features, hidden_nodes) - ).double() - if unpack_inputs: - net.pre.weight.data = torch.normal( - 3, 4, (in_features, in_features * num_inputs) - ) - if hasattr(net, "pre"): - net.pre.weight.data = net.pre.weight.data.double() - checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"]) - net_adjusted = ( - _wrap_model_in_dataparallel(net) - if use_gpu == "cuda_data_parallel" - else (net.to(device="cuda") if use_gpu == "cuda" else net) - ) - torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name)) + # generate data. regardless the model, the data is always generated the same way + # the only exception is if the `model_type` is 'trained_linear', i.e. a simple + # linear regression model. this is a simple model, and for simplicity, the + # number of `out_features` is 1 in this case. 
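    # Editorial sketch (X and y stand for the Hessian-data features and labels
    # built below): the 'trained_linear' branch finds the least-squares optimum,
    # i.e. the theta minimizing ||X theta - y||^2, which satisfies the normal
    # equations (X'X) theta = X'y. Ignoring conditioning, a direct computation
    # would be
    #
    #     theta = torch.linalg.solve(X.T @ X, X.T @ y)
    #
    # the code instead relies on `torch.lstsq` / `torch.linalg.lstsq`, chosen by
    # PyTorch version. At this optimum the Hessian of the summed squared error is
    # the constant matrix 2 X'X, which is positive definite whenever X has full
    # column rank - the property the docstring above relies on when comparing
    # `InfluenceFunctionBase` implementations.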
+ if model_type == "trained_linear": + out_features = 1 + else: + out_features = 3 num_samples = 50 num_train = 32 + num_hessian = 22 # this needs to be high to prevent numerical issues all_labels = torch.normal(1, 2, (num_samples, out_features)).double() + all_labels = all_labels.cuda() if use_gpu else all_labels train_labels = all_labels[:num_train] test_labels = all_labels[num_train:] + hessian_labels = all_labels[:num_hessian] if unpack_inputs: all_samples = [ torch.normal(0, 1, (num_samples, in_features)).double() for _ in range(num_inputs) ] + all_samples = ( + _move_sample_to_cuda(all_samples) + if isinstance(all_samples, list) and use_gpu + else all_samples.cuda() + if use_gpu + else all_samples + ) train_samples = [ts[:num_train] for ts in all_samples] test_samples = [ts[num_train:] for ts in all_samples] + hessian_samples = [ts[:num_hessian] for ts in all_samples] else: all_samples = torch.normal(0, 1, (num_samples, in_features)).double() + all_samples = ( + _move_sample_to_cuda(all_samples) + if isinstance(all_samples, list) and use_gpu + else all_samples.cuda() + if use_gpu + else all_samples + ) train_samples = all_samples[:num_train] test_samples = all_samples[num_train:] + hessian_samples = all_samples[:num_hessian] dataset = ( ExplicitDataset(train_samples, train_labels, use_gpu) @@ -250,26 +314,191 @@ def get_random_model_and_data( else UnpackDataset(train_samples, train_labels, use_gpu) ) - if return_test_data: - return ( + if model_type == "random": + net = ( + BasicLinearNet(in_features, hidden_nodes, out_features) + if not unpack_inputs + else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs) + ).double() + + # generate checkpoints randomly + num_checkpoints = 5 + + for i in range(num_checkpoints): + net.linear1.weight.data = torch.normal( + 3, 4, (hidden_nodes, in_features) + ).double() + net.linear2.weight.data = torch.normal( + 5, 6, (out_features, hidden_nodes) + ).double() + if unpack_inputs: + net.pre.weight.data = torch.normal( + 3, 4, (in_features, in_features * num_inputs) + ).double() + checkpoint_name = "-".join(["checkpoint-reg", str(i + 1) + ".pt"]) + net_adjusted = ( + _wrap_model_in_dataparallel(net) + if use_gpu == "cuda_data_parallel" + else (net.to(device="cuda") if use_gpu == "cuda" else net) + ) + torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name)) + + elif model_type == "trained_linear": + net = ( + Linear(in_features) + if not unpack_inputs + else UnpackLinear(in_features, num_inputs) + ).double() + + # regardless of `unpack_inputs`, the model is a linear regression, so that + # we can get the optimal trained parameters via least squares + + # turn input into a single tensor for use by least squares + tensor_hessian_samples = ( + hessian_samples if not unpack_inputs else torch.cat(hessian_samples, dim=1) + ) + version = _parse_version(torch.__version__) + if version < (1, 9): + theta = torch.lstsq(tensor_hessian_samples, hessian_labels).solution[0:1] + else: + # run least squares to get optimal trained parameters + theta = torch.linalg.lstsq( + hessian_labels, + tensor_hessian_samples, + ).solution + # the first `n` rows of `theta` contains the least squares solution, where + # `n` is the number of features in `tensor_hessian_samples` + theta = theta[: tensor_hessian_samples.shape[1]] + + # save that trained parameter as a checkpoint + checkpoint_name = "checkpoint-final.pt" + net.linear.weight.data = theta.contiguous() + net_adjusted = ( _wrap_model_in_dataparallel(net) if use_gpu == "cuda_data_parallel" 
-            else (net.to(device="cuda") if use_gpu == "cuda" else net),
-            dataset,
-            _move_sample_to_cuda(test_samples)
-            if isinstance(test_samples, list) and use_gpu
-            else test_samples.cuda()
-            if use_gpu
-            else test_samples,
-            test_labels.cuda() if use_gpu else test_labels,
+            else (net.to(device="cuda") if use_gpu == "cuda" else net)
         )
-    else:
-        return (
+        torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))
+
+    elif model_type == "trained_NN":
+        net = (
+            BasicLinearNet(in_features, hidden_nodes, out_features)
+            if not unpack_inputs
+            else MultLinearNet(in_features, hidden_nodes, out_features, num_inputs)
+        ).double()
+
+        net_adjusted = (
             _wrap_model_in_dataparallel(net)
             if use_gpu == "cuda_data_parallel"
-            else (net.to(device="cuda") if use_gpu == "cuda" else net),
-            dataset,
+            else (net.to(device="cuda") if use_gpu == "cuda" else net)
+        )
+
+        # train model using several optimization steps on Hessian data
+
+        # create entire Hessian data as a batch
+        hessian_dataset = (
+            ExplicitDataset(hessian_samples, hessian_labels, use_gpu)
+            if not unpack_inputs
+            else UnpackDataset(hessian_samples, hessian_labels, use_gpu)
+        )
+        batch = next(iter(DataLoader(hessian_dataset, batch_size=num_hessian)))
+
+        optimizer = torch.optim.Adam(net.parameters())
+        num_steps = 200
+        criterion = nn.MSELoss(reduction="sum")
+        for _ in range(num_steps):
+            optimizer.zero_grad()
+            output = net_adjusted(*batch[:-1])
+            loss = criterion(output, batch[-1])
+            loss.backward()
+            optimizer.step()
+
+        # save the trained parameters as a checkpoint
+        checkpoint_name = "checkpoint-final.pt"
+        net_adjusted = (
+            _wrap_model_in_dataparallel(net) if use_gpu == "cuda_data_parallel" else net
+        )
+        torch.save(net_adjusted.state_dict(), os.path.join(tmpdir, checkpoint_name))
+
+    training_data = (
+        net_adjusted,
+        dataset,
+    )
+
+    hessian_data = (
+        _move_sample_to_cuda(hessian_samples)
+        if isinstance(hessian_samples, list) and use_gpu
+        else hessian_samples.cuda()
+        if use_gpu
+        else hessian_samples,
+        hessian_labels.cuda() if use_gpu else hessian_labels,
+    )
+
+    test_data = (
+        _move_sample_to_cuda(test_samples)
+        if isinstance(test_samples, list) and use_gpu
+        else test_samples.cuda()
+        if use_gpu
+        else test_samples,
+        test_labels.cuda() if use_gpu else test_labels,
+    )

+    if return_test_data:
+        if not return_hessian_data:
+            return (*training_data, *test_data)
+        else:
+            return (*training_data, *hessian_data, *test_data)
+    else:
+        if not return_hessian_data:
+            return training_data
+        else:
+            return (*training_data, *hessian_data)
+
+
+def generate_symmetric_matrix_given_eigenvalues(
+    eigenvalues: Union[Tensor, List[float]]
+) -> Tensor:
+    """
+    Following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L123 # noqa: E501
+    generates a symmetric random matrix with the specified eigenvalues. This is used
+    in `TestArnoldiInfluence._test_parameter_arnoldi_and_distill`, either to check
+    that `_parameter_arnoldi` returns the top eigenvalues of a symmetric random
+    matrix, or that `_parameter_distill` returns the eigenvectors corresponding to
+    those top eigenvalues.
+    """
+    # generate a random matrix, then orthonormalize its columns via a QR
+    # decomposition to get a random orthonormal basis `Q`
+    D = len(eigenvalues)
+    version = _parse_version(torch.__version__)
+    if version < (1, 8):
+        Q, _ = torch.qr(torch.randn((D, D)))
+    else:
+        Q, _ = torch.linalg.qr(torch.randn((D, D)))
+    return torch.matmul(Q, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q.T))
+
+
+def generate_assymetric_matrix_given_eigenvalues(
+    eigenvalues: Union[Tensor, List[float]]
+) -> Tensor:
+    """
+    Following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
+    generates an asymmetric random matrix with the specified eigenvalues. This is
+    used in `TestArnoldiInfluence._test_parameter_arnoldi_and_distill`, either to
+    check that `_parameter_arnoldi` returns the top eigenvalues of an asymmetric
+    random matrix, or that `_parameter_distill` returns the eigenvectors
+    corresponding to those top eigenvalues.
+    """
+    # the matrix M, given eigenvectors Q (as columns) and eigenvalues L, should
+    # satisfy MQ = QL, or equivalently, Q'M' = LQ'. so we obtain M' by solving the
+    # linear system Q'X = LQ', and return M = X'.
+    D = len(eigenvalues)
+    Q_T = torch.randn((D, D))
+    version = _parse_version(torch.__version__)
+    if version < (1, 8):
+        # `torch.solve(B, A)` solves `AX = B`, i.e. its arguments are in the
+        # opposite order of those of `torch.linalg.solve(A, B)`
+        X, _ = torch.solve(
+            torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q_T), Q_T
+        )
+        return X.T
+    return torch.linalg.solve(
+        Q_T, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q_T)
+    ).T


 class DataInfluenceConstructor:
@@ -295,11 +524,14 @@ def __init__(
     def __repr__(self) -> str:
         return self.name

+    def __name__(self) -> str:
+        return self.name
+
     def __call__(
         self,
         net: Module,
         dataset: Union[Dataset, DataLoader],
-        tmpdir: Union[str, List[str], Iterator],
+        tmpdir: str,
         batch_size: Union[int, None],
         loss_fn: Optional[Union[Module, Callable]],
         **kwargs,
@@ -324,6 +556,26 @@ def __call__(
                 batch_size=batch_size,
                 **constructor_kwargs,
             )
+        elif self.data_influence_class in [
+            NaiveInfluenceFunction,
+            ArnoldiInfluenceFunction,
+        ]:
+            # for these implementations, only a single checkpoint is needed, not
+            # a directory containing several checkpoints. therefore, given
+            # directory `tmpdir`, we do not pass it directly to the constructor,
+            # but instead pick a single checkpoint in it (the lexicographically
+            # last one), and pass that to the constructor
+            checkpoint_name = sorted(os.listdir(tmpdir))[-1]
+            checkpoint = os.path.join(tmpdir, checkpoint_name)
+
+            return self.data_influence_class(
+                net,
+                dataset,
+                checkpoint,
+                loss_fn=loss_fn,
+                batch_size=batch_size,
+                **constructor_kwargs,
+            )
         else:
             return self.data_influence_class(
                 net,
@@ -371,6 +623,8 @@ def generate_test_name(
         if isinstance(arg, bool):
             if arg:
                 param_strs.append(func_param_names[i])
+        elif isfunction(arg):
+            param_strs.append(arg.__name__)
         else:
             args_str = str(arg)
             if args_str.isnumeric():
@@ -397,3 +651,10 @@ def _format_batch_into_tuple(
         return (*inputs, targets)
     else:
         return (inputs, targets)
+
+
+USE_GPU_LIST = (
+    [False, "cuda"]
+    if torch.cuda.is_available() and torch.cuda.device_count() != 0
+    else [False]
+)
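For readers skimming the two matrix generators above, the following standalone snippet (not part of the diff) illustrates the invariant they are meant to satisfy: the eigenvalues recovered from the generated matrices should match the requested ones up to ordering. The import path `tests.influence._utils.common` is an assumption about where these helpers live, and `torch.linalg.eigvals` requires a PyTorch recent enough to have the `torch.linalg` namespace.

# illustrative sanity check only; not part of the patch
import torch

from tests.influence._utils.common import (  # assumed location of the helpers
    generate_assymetric_matrix_given_eigenvalues,
    generate_symmetric_matrix_given_eigenvalues,
)

eigenvalues = [5.0, 2.0, 1.0]
for generate in (
    generate_symmetric_matrix_given_eigenvalues,
    generate_assymetric_matrix_given_eigenvalues,
):
    A = generate(eigenvalues)
    # both constructions have a real spectrum, so imaginary parts are ~0;
    # `torch.linalg.eigvals` returns eigenvalues in no particular order
    recovered = torch.sort(torch.linalg.eigvals(A).real, descending=True).values
    # loose tolerance, since the random eigenvector basis may be ill-conditioned
    assert torch.allclose(recovered, torch.tensor(eigenvalues), atol=1e-3)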