You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
Pull Request resolved: #1072
This diff changes the API for implementations of `TracInCPBase` as discussed in https://fb.quip.com/JbpnAiWluZmI. In particular, the arguments representing test data of the `influence` method are changed from `inputs: Tuple, targets: Optional[Tensor]` to `inputs: Union[Tuple[Any], DataLoader]`, which is either a single batch, or a dataloader yielding batches. In both cases, `model(*batch)` is assumed to produce the predictions for a batch, and `batch[-1]` is assumed to be the labels for a batch. This is the same format assumed of the batches yielded by `train_dataloader`.
We make this change for 2 reasons
- it unifies the assumptions made of the test data and the assumptions made of the training data
- for some implementations, we want to allow the test data to be represented by a dataloader. with the old API, there was no clean way to allow both a single as well as a dataloader to be passed in, since a batch required 2 arguments, but a dataloader only requires 1.
For now, all implementations only allow `inputs` to be a tuple (and not a dataloader). This is okay due to inheritance rules. Later on, we will allow some implementations (i.e. `TracInCP`) to accept a dataloader as `inputs`.
Other changes:
- changes to make documentation. for example, documentation in `TracInCPBase.influence` now refers to the "test dataset" instead of test batch.
- the `unpack_inputs` argument is no longer needed for the `influence` methods, and is removed
- the usage of `influence` in all the tests is changed to match new API.
- signature of helper methods `_influence_batch_tracincp` and `_influence_batch_tracincp_fast` are changed to match new representation of batches.
Differential Revision: https://internalfb.com/D41324297
fbshipit-source-id: d962b940452685e7e488986f11c769633b3d3e2d
0 commit comments