Skip to content

Commit 06055f2

Browse files
3l1facebook-github-bot
authored andcommitted
Fixup op_slice negative start arguments
Summary: Fixup op_slice negative start arguments Reviewed By: digantdesai Differential Revision: D72728353
1 parent eeabc29 commit 06055f2

File tree

2 files changed

+61
-35
lines changed

2 files changed

+61
-35
lines changed

backends/arm/operators/op_slice.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ class SliceVisitor(NodeVisitor):
2323
def __init__(self, *args):
2424
super().__init__(*args)
2525

26+
def _fixup_start(self, start, shape, dim):
27+
if start.number < 0:
28+
return start.number % shape[dim]
29+
else:
30+
return start.number
31+
32+
def _fixup_end(self, end, shape, dim):
33+
if end.number < 0:
34+
return end.number % shape[dim]
35+
else:
36+
return min(end.number, shape[dim])
37+
2638
def define_node(
2739
self,
2840
node: Node,
@@ -42,17 +54,18 @@ def define_node(
4254
# Translate and check parameters in Pytorch dim order.
4355
shape = input_node.shape
4456
dim = dim.number
45-
if end.number < 0:
46-
end_index = end.number % shape[dim]
47-
else:
48-
end_index = min(end.number, shape[dim])
49-
size = end_index - start.number
57+
58+
start_index = self._fixup_start(start, shape, dim)
59+
end_index = self._fixup_end(end, shape, dim)
60+
size = end_index - start_index
61+
5062
assert size > 0
5163
assert size <= shape[dim]
5264

5365
# Convert aten args to Tosa's start and size attributes and in TOSA dim order.
5466
attr = ts.TosaSerializerAttribute()
55-
start_attr = [start.number if i == dim else 0 for i in input_node.dim_order]
67+
68+
start_attr = [self._fixup_start(start, shape, dim) if i == dim else 0 for i in input_node.dim_order]
5669
size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
5770
attr.SliceAttribute(start_attr, size_attr)
5871

backends/arm/test/ops/test_slice.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,36 @@
1616
from executorch.exir.backend.compile_spec_schema import CompileSpec
1717
from parameterized import parameterized
1818

19+
test_data_suite = [
20+
(
21+
torch.ones(10),
22+
[(3, -3)]
23+
),
24+
(
25+
torch.ones(10),
26+
[(-8, 3)]
27+
),
28+
(
29+
torch.ones(10, 10),
30+
[(1, 3), (3, None)]
31+
),
32+
(
33+
torch.ones(10, 10, 10),
34+
[(0, 7), (0, None), (0, 8)]
35+
),
36+
(
37+
torch.ones((1, 12, 10, 10)),
38+
[(None, None), (None, 5), (3, 5), (4, 10)]
39+
)
40+
]
1941

2042
class TestSimpleSlice(unittest.TestCase):
2143

2244
class Slice(torch.nn.Module):
23-
24-
sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))]
25-
test_tensors = [(torch.ones(n),) for n in sizes]
26-
27-
def forward(self, x: torch.Tensor):
28-
if x.dim() == 1:
29-
return x[3:-3]
30-
elif x.dim() == 2:
31-
return x[1:3, 3:]
32-
elif x.dim() == 3:
33-
return x[0:7, 0:, 0:8]
34-
elif x.dim() == 4:
35-
return x[:, :5, 3:5, 4:10]
45+
def forward(self, x: torch.Tensor, s: list[tuple[int, int]]):
46+
slices = [slice(*i) for i in s]
47+
return x[slices]
48+
3649

3750
def _test_slice_tosa_MI_pipeline(
3851
self, module: torch.nn.Module, test_data: torch.Tensor
@@ -111,26 +124,26 @@ def _test_slice_u85_BI_pipeline(
111124
self._test_slice_ethos_BI_pipeline(
112125
common.get_u85_compile_spec(), module, test_data
113126
)
114-
115-
@parameterized.expand(Slice.test_tensors)
127+
128+
@parameterized.expand(test_data_suite)
116129
@pytest.mark.tosa_ref_model
117-
def test_slice_tosa_MI(self, tensor):
118-
self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,))
130+
def test_slice_tosa_MI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
131+
self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor, slices))
119132

120-
@parameterized.expand(Slice.test_tensors[:2])
133+
@parameterized.expand(test_data_suite)
121134
@pytest.mark.tosa_ref_model
122-
def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor):
123-
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))
135+
def test_slice_nchw_tosa_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
136+
self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices))
124137

125-
@parameterized.expand(Slice.test_tensors[2:])
138+
@parameterized.expand(test_data_suite)
126139
@pytest.mark.tosa_ref_model
127-
def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor):
128-
self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,))
140+
def test_slice_nhwc_tosa_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
141+
self._test_slice_tosa_BI_pipeline(self.Slice(), (tensor, slices))
129142

130-
@parameterized.expand(Slice.test_tensors)
131-
def test_slice_u55_BI(self, test_tensor: torch.Tensor):
132-
self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,))
143+
@parameterized.expand(test_data_suite)
144+
def test_slice_u55_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
145+
self._test_slice_u55_BI_pipeline(self.Slice(), (tensor, slices))
133146

134-
@parameterized.expand(Slice.test_tensors)
135-
def test_slice_u85_BI(self, test_tensor: torch.Tensor):
136-
self._test_slice_u85_BI_pipeline(self.Slice(), (test_tensor,))
147+
@parameterized.expand(test_data_suite)
148+
def test_slice_u85_BI(self, tensor: torch.Tensor, slices: list[tuple[int, int]]):
149+
self._test_slice_u85_BI_pipeline(self.Slice(), (tensor, slices))

0 commit comments

Comments
 (0)