|
16 | 16 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
17 | 17 | from parameterized import parameterized |
18 | 18 |
|
| 19 | +test_data_suite = [ |
| 20 | + ( |
| 21 | + torch.ones(10), |
| 22 | + [(3, -3)] |
| 23 | + ), |
| 24 | + ( |
| 25 | + torch.ones(10), |
| 26 | + [(-8, 3)] |
| 27 | + ), |
| 28 | + ( |
| 29 | + torch.ones(10, 10), |
| 30 | + [(1, 3), (3, None)] |
| 31 | + ), |
| 32 | + ( |
| 33 | + torch.ones(10, 10, 10), |
| 34 | + [(0, 7), (0, None), (0, 8)] |
| 35 | + ), |
| 36 | + ( |
| 37 | + torch.ones((1, 12, 10, 10)), |
| 38 | + [(None, None), (None, 5), (3, 5), (4, 10)] |
| 39 | + ) |
| 40 | +] |
19 | 41 |
|
20 | 42 | class TestSimpleSlice(unittest.TestCase): |
21 | 43 |
|
22 | 44 | class Slice(torch.nn.Module): |
23 | | - |
24 | | - sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))] |
25 | | - test_tensors = [(torch.ones(n),) for n in sizes] |
26 | | - |
27 | | - def forward(self, x: torch.Tensor): |
28 | | - if x.dim() == 1: |
29 | | - return x[3:-3] |
30 | | - elif x.dim() == 2: |
31 | | - return x[1:3, 3:] |
32 | | - elif x.dim() == 3: |
33 | | - return x[0:7, 0:, 0:8] |
34 | | - elif x.dim() == 4: |
35 | | - return x[:, :5, 3:5, 4:10] |
| 45 | + def forward(self, x: torch.Tensor, s: list[tuple[int, int]]): |
| 46 | + slices = [slice(*i) for i in s] |
| 47 | + return x[slices] |
| 48 | + |
36 | 49 |
|
37 | 50 | def _test_slice_tosa_MI_pipeline( |
38 | 51 | self, module: torch.nn.Module, test_data: torch.Tensor |
@@ -111,26 +124,26 @@ def _test_slice_u85_BI_pipeline( |
111 | 124 | self._test_slice_ethos_BI_pipeline( |
112 | 125 | common.get_u85_compile_spec(), module, test_data |
113 | 126 | ) |
114 | | - |
115 | | - @parameterized.expand(Slice.test_tensors) |
| 127 | + |
| 128 | + @parameterized.expand(test_data_suite) |
116 | 129 | @pytest.mark.tosa_ref_model |
117 | | - def test_slice_tosa_MI(self, tensor): |
118 | | - self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,)) |
| 130 | + def test_slice_tosa_MI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 131 | + self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor, slices)) |
119 | 132 |
|
120 | | - @parameterized.expand(Slice.test_tensors[:2]) |
| 133 | + @parameterized.expand(test_data_suite) |
121 | 134 | @pytest.mark.tosa_ref_model |
122 | | - def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor): |
123 | | - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) |
| 135 | + def test_slice_nchw_tosa_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 136 | + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) |
124 | 137 |
|
125 | | - @parameterized.expand(Slice.test_tensors[2:]) |
| 138 | + @parameterized.expand(test_data_suite) |
126 | 139 | @pytest.mark.tosa_ref_model |
127 | | - def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor): |
128 | | - self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,)) |
| 140 | + def test_slice_nhwc_tosa_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 141 | + self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices)) |
129 | 142 |
|
130 | | - @parameterized.expand(Slice.test_tensors) |
131 | | - def test_slice_u55_BI(self, test_tensor: torch.Tensor): |
132 | | - self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,)) |
| 143 | + @parameterized.expand(test_data_suite) |
| 144 | + def test_slice_u55_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 145 | + self._test_slice_u55_BI_pipeline(self.Slice(), (tensor, slices)) |
133 | 146 |
|
134 | | - @parameterized.expand(Slice.test_tensors) |
135 | | - def test_slice_u85_BI(self, test_tensor: torch.Tensor): |
136 | | - self._test_slice_u85_BI_pipeline(self.Slice(), (test_tensor,)) |
| 147 | + @parameterized.expand(test_data_suite) |
| 148 | + def test_slice_u85_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]): |
| 149 | + self._test_slice_u85_BI_pipeline(self.Slice(), (tensor, slices)) |
0 commit comments