Skip to content

Commit 3963723

Browse files
cicichen01facebook-github-bot
authored andcommitted
Add test for influence utils (#1265)
Summary: Pull Request resolved: #1265 As titled. - Initial the test for influence utils. Reviewed By: yucu Differential Revision: D55375620 fbshipit-source-id: f9f9aee988b76c0ca64089c01b56aecb66e8f824
1 parent dab9447 commit 3963723

File tree

3 files changed

+65
-14
lines changed

3 files changed

+65
-14
lines changed

captum/influence/_utils/common.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _jacobian_loss_wrt_inputs(
108108
batch).
109109
110110
Args:
111-
loss_fn (torch.nn.Module, Callable, or None): The loss function. If a library
111+
loss_fn (torch.nn.Module, Callable): The loss function. If a library
112112
defined loss function is provided, it would be expected to be a
113113
torch.nn.Module. If a custom loss is provided, it can be either type,
114114
but must behave as a library loss function would if `reduction='sum'`
@@ -131,24 +131,21 @@ def _jacobian_loss_wrt_inputs(
131131
in the batch represented by `out`. This is a 2D tensor, where the
132132
first dimension is the batch dimension.
133133
"""
134-
# TODO: allow loss_fn to be Callable
135-
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
136-
msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
137-
138-
assert loss_fn.reduction != "none", msg0
139-
msg1 = (
140-
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
141-
f"reduction type ({reduction_type}). Please ensure they are"
142-
" matching."
143-
)
144-
assert loss_fn.reduction == reduction_type, msg1
145-
146134
if reduction_type != "sum" and reduction_type != "mean":
147135
raise ValueError(
148-
f"{reduction_type} is not a valid value for reduction_type. "
136+
f"`{reduction_type}` is not a valid value for reduction_type. "
149137
"Must be either 'sum' or 'mean'."
150138
)
151139

140+
# TODO: allow loss_fn to be Callable
141+
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
142+
msg = (
143+
f"loss_fn.reduction `{loss_fn.reduction}` does not match"
144+
f"reduction type `{reduction_type}`. Please ensure they are"
145+
" matching."
146+
)
147+
assert loss_fn.reduction == reduction_type, msg
148+
152149
if _parse_version(torch.__version__) >= (1, 8, 0):
153150
input_jacobians = torch.autograd.functional.jacobian(
154151
lambda out: loss_fn(out, targets), out, vectorize=vectorize

tests/influence/_utils/__init__.py

Whitespace-only changes.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# !/usr/bin/env python3
4+
5+
import torch
6+
7+
from captum.influence._utils.common import _jacobian_loss_wrt_inputs
8+
from tests.helpers import BaseTest
9+
from tests.helpers.basic import assertTensorAlmostEqual
10+
11+
12+
class TestCommon(BaseTest):
13+
def setUp(self) -> None:
14+
super().setUp()
15+
16+
def test_jacobian_loss_wrt_inputs(self) -> None:
17+
with self.assertRaises(ValueError) as err:
18+
_jacobian_loss_wrt_inputs(
19+
torch.nn.BCELoss(reduction="sum"),
20+
torch.tensor([-1.0, 1.0]),
21+
torch.tensor([1.0]),
22+
True,
23+
"",
24+
)
25+
self.assertEqual(
26+
"`` is not a valid value for reduction_type. "
27+
"Must be either 'sum' or 'mean'.",
28+
str(err.exception),
29+
)
30+
31+
with self.assertRaises(AssertionError) as err:
32+
_jacobian_loss_wrt_inputs(
33+
torch.nn.BCELoss(reduction="sum"),
34+
torch.tensor([-1.0, 1.0]),
35+
torch.tensor([1.0]),
36+
True,
37+
"mean",
38+
)
39+
self.assertEqual(
40+
"loss_fn.reduction `sum` does not matchreduction type `mean`."
41+
" Please ensure they are matching.",
42+
str(err.exception),
43+
)
44+
45+
res = _jacobian_loss_wrt_inputs(
46+
torch.nn.BCELoss(reduction="sum"),
47+
torch.tensor([0.5, 1.0]),
48+
torch.tensor([0.0, 1.0]),
49+
True,
50+
"sum",
51+
)
52+
assertTensorAlmostEqual(
53+
self, res, torch.tensor([2.0, 0.0]), delta=0.0, mode="sum"
54+
)

0 commit comments

Comments
 (0)