|
7 | 7 |
|
8 | 8 |
|
9 | 9 | class TestNegConverter(DispatchTestCase): |
10 | | - def test_neg(self): |
| 10 | + @parameterized.expand( |
| 11 | + [ |
| 12 | + ("2d_dim_dtype_float", (2, 2), torch.float), |
| 13 | + ("3d_dim_dtype_float", (2, 2, 2), torch.float), |
| 14 | + |
| 15 | + ] |
| 16 | + ) |
| 17 | + def test_neg_float(self, _, x, type): |
11 | 18 | class neg(nn.Module): |
12 | 19 | def forward(self, input): |
13 | 20 | return torch.neg(input) |
14 | | - |
15 | | - inputs = [torch.randn(1, 10)] |
| 21 | + |
| 22 | + inputs = [torch.randn(x, dtype=type)] |
16 | 23 | self.run_test( |
17 | 24 | neg(), |
18 | 25 | inputs, |
19 | 26 | expected_ops={torch.ops.aten.neg.default}, |
20 | 27 | ) |
21 | 28 |
|
| 29 | + @parameterized.expand( |
| 30 | + [ |
| 31 | + ("2d_dim_dtype_int", (2, 2), torch.int32, 0, 5), |
| 32 | + ("3d_dim_dtype_int", (2, 2, 2), torch.int32, 0, 5), |
| 33 | + ] |
| 34 | + ) |
| 35 | + |
| 36 | + def test_neg_int(self, _, x, type, min, max): |
| 37 | + class neg(nn.Module): |
| 38 | + def forward(self, input): |
| 39 | + return torch.neg(input) |
| 40 | + |
| 41 | + inputs = [torch.randint(min, max, (x), dtype=type)] |
| 42 | + |
| 43 | + self.run_test( |
| 44 | + neg(), |
| 45 | + inputs, |
| 46 | + expected_ops={torch.ops.aten.neg.default}, |
| 47 | + ) |
22 | 48 |
|
23 | 49 | if __name__ == "__main__": |
24 | 50 | run_tests() |
0 commit comments