Commit afa439f
let tracincp aggregate influence (#1088)
Summary:
This diff adds an "aggregate" option to `TracInCP.influence`. The "aggregate" influence score of a training example on a test dataset is the sum of the influence of the training example on all examples in the test dataset. When `aggregate` is True, `influence` in influence score mode returns a 2D tensor of shape (1, training dataset size) containing aggregate influence scores of all training examples. When `aggregate` is True, `influence` in k most influential mode returns a 2D tensor of shape (1, k) of proponents (or opponents), and a 2D tensor containing the corresponding aggregate influence scores, of the same shape.
This option is only added for `TracInCP` because, for it, aggregate influence can be computed more quickly than by naively computing the influence score of every training example on every test example and then summing across test examples. In particular, we can first sum the Jacobians across all test examples, and then take the dot product of that sum with the Jacobians of the training examples. (All of this is done across checkpoints.)
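The speed-up described above rests on the linearity of the dot product. Below is a minimal, stdlib-only sketch of that identity, not Captum's implementation: per-example gradients are stood in for by random vectors, any per-checkpoint weighting is ignored, and the names `naive_aggregate`, `fast_aggregate`, and `random_grads` are hypothetical helpers introduced only for illustration.

```python
import random


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def naive_aggregate(train_grads, test_grads):
    """Score every (train, test) pair, then sum over test examples and checkpoints.

    train_grads[c][i] is the (toy) gradient of training example i at checkpoint c;
    test_grads[c][j] is the same for test example j.
    """
    n_train = len(train_grads[0])
    return [
        sum(
            dot(train_grads[c][i], g_test)
            for c in range(len(train_grads))
            for g_test in test_grads[c]
        )
        for i in range(n_train)
    ]


def fast_aggregate(train_grads, test_grads):
    """Sum the test gradients first, then take one dot product per training example."""
    scores = [0.0] * len(train_grads[0])
    for ckpt_train, ckpt_test in zip(train_grads, test_grads):
        # Element-wise sum of all test gradients at this checkpoint.
        summed_test = [sum(column) for column in zip(*ckpt_test)]
        for i, g_train in enumerate(ckpt_train):
            scores[i] += dot(g_train, summed_test)
    return scores


def random_grads(rng, n_ckpt, n_examples, dim):
    """Random stand-ins for per-example gradients (an assumption of this sketch)."""
    return [
        [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(n_examples)]
        for _ in range(n_ckpt)
    ]


rng = random.Random(0)
train_grads = random_grads(rng, n_ckpt=2, n_examples=5, dim=4)
test_grads = random_grads(rng, n_ckpt=2, n_examples=3, dim=4)

naive = naive_aggregate(train_grads, test_grads)
fast = fast_aggregate(train_grads, test_grads)
```

With N training examples and M test examples, the naive version takes N × M dot products per checkpoint, while the summed version takes N dot products plus one pass over the test gradients, so the size of the test set no longer multiplies the cost.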
Since computing aggregate influence scores is efficient, even if the test dataset is large, we now allow `inputs` for `influence` to be a dataloader, so that it does not need to fit in memory.
One use case of aggregate influence is to compute the influence of a training example on some validation metric, e.g. a fairness metric.
We add the following tests:
- in the newly added `test_tracin_aggregate_influence`, `test_tracin_aggregate_influence` tests that calling `influence` with `aggregate=True` gives the same result as calling it with `aggregate=False` and then summing.
- in the newly added `test_tracin_aggregate_influence`, `test_tracin_aggregate_influence_api` tests that the result of calling `influence` with `aggregate=True` on a DataLoader of batches is the same as when the batches are collated into a single batch.
- in `test_tracin_k_most_influential`, we modify the test to allow `aggregate` to be true, which tests that the proponents computed by `influence` with the memory-saving approach are the same as the proponents computed by calculating all aggregate influence scores and then sorting (not memory efficient).
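
The last test compares a memory-saving top-k computation against sorting all scores. A minimal sketch of that kind of check is shown below, assuming the aggregate scores arrive in batches of training examples. It is an illustration only and does not reproduce Captum's code or test; `top_k_by_sorting` and `top_k_batched` are hypothetical names.

```python
import heapq
import random


def top_k_by_sorting(scores, k):
    """Reference approach: hold every score in memory, sort, keep the k best indices."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def top_k_batched(score_batches, k):
    """Memory-saving approach: keep only the running k best (score, index) pairs."""
    best = []
    offset = 0  # global index of the first example in the current batch
    for batch in score_batches:
        candidates = best + [(score, offset + j) for j, score in enumerate(batch)]
        best = heapq.nlargest(k, candidates)
        offset += len(batch)
    return [index for _, index in best]


rng = random.Random(0)
scores = [rng.random() for _ in range(10)]  # toy aggregate influence scores
batches = [scores[i:i + 4] for i in range(0, len(scores), 4)]

proponents_reference = top_k_by_sorting(scores, k=3)
proponents_batched = top_k_batched(batches, k=3)
```

The batched version never holds more than k plus one batch of scores at a time, which is why it suits large training sets. The two functions may order tied scores differently, so this sketch assumes distinct scores.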
Reviewed By: cyrjano
Differential Revision: D41830245
1 parent ed3b1fa commit afa439f
File tree
4 files changed (+353, -83 lines)
- captum/influence
- _core
- _utils
- tests/influence/_core