|
8 | 8 |
|
9 | 9 | class Test(BaseTest): |
10 | 10 | def test_validate_input(self) -> None: |
11 | | - with self.assertRaises(AssertionError): |
12 | | - _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) |
| 11 | + with self.assertRaises(AssertionError) as err: |
| 12 | + _validate_input( |
| 13 | + (torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0, 0.0, 1.0]),) |
| 14 | + ) |
| 15 | + self.assertEqual( |
| 16 | + "Baseline can be provided as a tensor for just one input and " |
| 17 | + "broadcasted to the batch or input and baseline must have the " |
| 18 | + "same shape or the baseline corresponding to each input tensor " |
| 19 | + "must be a scalar. Found baseline: tensor([-2., 0., 1.]) and " |
| 20 | + "input: tensor([-1., 1.])", |
| 21 | + str(err.exception), |
| 22 | + ) |
| 23 | + |
| 24 | + with self.assertRaises(AssertionError) as err: |
13 | 25 | _validate_input( |
14 | 26 | (torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1 |
15 | 27 | ) |
| 28 | + self.assertEqual( |
| 29 | + "The number of steps must be a positive integer. Given: -1", |
| 30 | + str(err.exception), |
| 31 | + ) |
| 32 | + |
| 33 | + with self.assertRaises(AssertionError) as err: |
16 | 34 | _validate_input( |
17 | 35 | (torch.tensor([-1.0, 1.0]),), |
18 | 36 | (torch.tensor([-1.0, 1.0]),), |
19 | 37 | method="abcde", |
20 | 38 | ) |
| 39 | + self.assertIn( |
| 40 | + "Approximation method must be one for the following", |
| 41 | + str(err.exception), |
| 42 | + ) |
| 43 | + # any baseline which is broadcastable to match the input is supported, which |
| 44 | + # includes a scalar / single-element tensor. |
| 45 | + _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) |
21 | 46 | _validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),)) |
22 | 47 | _validate_input( |
23 | 48 | (torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre" |
24 | 49 | ) |
25 | 50 |
|
26 | 51 | def test_validate_nt_type(self) -> None: |
27 | | - with self.assertRaises(AssertionError): |
| 52 | + with self.assertRaises( |
| 53 | + AssertionError, |
| 54 | + ) as err: |
28 | 55 | _validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES) |
| 56 | + self.assertIn( |
| 57 | + "Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.", |
| 58 | + str(err.exception), |
| 59 | + ) |
| 60 | + |
29 | 61 | _validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES) |
30 | 62 | _validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES) |
31 | 63 | _validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES) |
0 commit comments