Skip to content

Commit dab9447

Browse files
cicichen01facebook-github-bot
authored andcommitted
Fix the tests with exceptions assertions having multiple top-level statements (#1264)
Summary: Pull Request resolved: #1264 As titled. Add error message checks as well for test comprehensity and code readability. B908: Contexts with exceptions assertions like with self.assertRaises or with pytest.raises should not have multiple top-level statements. Each statement should be in its own context. That way, the test ensures that the exception is raised only in the exact statement where you expect it. Reviewed By: vivekmig Differential Revision: D55344319 fbshipit-source-id: aad315a15f764a9fb24d653c46c3940ae99248e9
1 parent 88f4b0a commit dab9447

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

tests/attr/test_common.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,56 @@
88

99
class Test(BaseTest):
1010
def test_validate_input(self) -> None:
11-
with self.assertRaises(AssertionError):
12-
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))
11+
with self.assertRaises(AssertionError) as err:
12+
_validate_input(
13+
(torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0, 0.0, 1.0]),)
14+
)
15+
self.assertEqual(
16+
"Baseline can be provided as a tensor for just one input and "
17+
"broadcasted to the batch or input and baseline must have the "
18+
"same shape or the baseline corresponding to each input tensor "
19+
"must be a scalar. Found baseline: tensor([-2., 0., 1.]) and "
20+
"input: tensor([-1., 1.])",
21+
str(err.exception),
22+
)
23+
24+
with self.assertRaises(AssertionError) as err:
1325
_validate_input(
1426
(torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1
1527
)
28+
self.assertEqual(
29+
"The number of steps must be a positive integer. Given: -1",
30+
str(err.exception),
31+
)
32+
33+
with self.assertRaises(AssertionError) as err:
1634
_validate_input(
1735
(torch.tensor([-1.0, 1.0]),),
1836
(torch.tensor([-1.0, 1.0]),),
1937
method="abcde",
2038
)
39+
self.assertIn(
40+
"Approximation method must be one for the following",
41+
str(err.exception),
42+
)
43+
# any baseline which is broadcastable to match the input is supported, which
44+
# includes a scalar / single-element tensor.
45+
_validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),))
2146
_validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),))
2247
_validate_input(
2348
(torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre"
2449
)
2550

2651
def test_validate_nt_type(self) -> None:
27-
with self.assertRaises(AssertionError):
52+
with self.assertRaises(
53+
AssertionError,
54+
) as err:
2855
_validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES)
56+
self.assertIn(
57+
"Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.",
58+
str(err.exception),
59+
)
60+
2961
_validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES)
3062
_validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES)
3163
_validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES)

0 commit comments

Comments
 (0)