Skip to content

Commit f13a82f

Browse files
mcr229facebook-github-bot
authored andcommitted
Move remove get item pass
Summary: Moving the tests for remove get item pass to use the new testing infra Differential Revision: D49718911 fbshipit-source-id: d961dc60d8d7494636bc29be57575a37afa53534
1 parent 346d6ba commit f13a82f

File tree

2 files changed

+100
-68
lines changed

2 files changed

+100
-68
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
11+
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
12+
13+
14+
class TestRemoveGetItemPass(unittest.TestCase):
15+
PassStage = RunPasses([RemoveGetItemPass])
16+
max_pool2d_name = "executorch_exir_dialects_edge__ops_aten_max_pool2d_default"
17+
amax_name = "executorch_exir_dialects_edge__ops_aten_amax_default"
18+
19+
class MaxPool2dModule(torch.nn.Module):
20+
def __init__(
21+
self,
22+
kernel_size=3,
23+
stride=1,
24+
padding=0,
25+
dilation=1,
26+
):
27+
super().__init__()
28+
self.max_pool2d_module = torch.nn.MaxPool2d(
29+
kernel_size=kernel_size,
30+
stride=stride,
31+
padding=padding,
32+
dilation=dilation,
33+
)
34+
35+
def forward(self, x):
36+
return self.max_pool2d_module(x)
37+
38+
def test_fp32_max_pool2d_remove_getitem(self):
39+
(
40+
Tester(self.MaxPool2dModule(), (torch.randn(4, 3, 24, 24),))
41+
.export()
42+
.to_edge()
43+
.run_passes(self.PassStage)
44+
.check_count({self.max_pool2d_name: 1})
45+
.run_method()
46+
.compare_outputs()
47+
)
48+
49+
def test_q8_max_pool2d_remove_getitem(self):
50+
(
51+
Tester(self.MaxPool2dModule(), (torch.randn(4, 3, 24, 24),))
52+
.quantize()
53+
.export()
54+
.to_edge()
55+
.run_passes(self.PassStage)
56+
.check_count({self.max_pool2d_name: 1})
57+
.run_method()
58+
.compare_outputs()
59+
)
60+
61+
class MaxModule(torch.nn.Module):
62+
def __init__(
63+
self,
64+
):
65+
super().__init__()
66+
67+
def forward(self, x):
68+
max_vals, indices = torch.max(x, dim=2, keepdim=True)
69+
return max_vals
70+
71+
def test_fp32_max_remove_getitem(self):
72+
(
73+
Tester(self.MaxModule(), (torch.randn(4, 3, 24, 24),))
74+
.export()
75+
.to_edge()
76+
.run_passes(self.PassStage)
77+
.check_count(
78+
{
79+
self.amax_name: 1,
80+
}
81+
)
82+
.run_method()
83+
.compare_outputs()
84+
)
85+
86+
def test_q8_max_remove_getitem(self):
87+
(
88+
Tester(self.MaxModule(), (torch.randn(4, 3, 24, 24),))
89+
.quantize()
90+
.export()
91+
.to_edge()
92+
.run_passes(self.PassStage)
93+
.check_count(
94+
{
95+
self.amax_name: 1,
96+
}
97+
)
98+
.run_method()
99+
.compare_outputs()
100+
)

backends/xnnpack/test/test_xnnpack_passes.py

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from executorch import exir
1212
from executorch.backends.xnnpack.passes import XNNPACKPassManager
1313
from executorch.backends.xnnpack.passes.convert_to_linear import ConvertToLinearPass
14-
from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
1514
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1615

1716
from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config
@@ -89,73 +88,6 @@ def capture_and_test_pass(
8988
)
9089
return new_exported_program
9190

92-
def test_max_pool2d_remove_getitem(self) -> None:
93-
passes = [RemoveGetItemPass()]
94-
95-
class MaxPool2dModule(torch.nn.Module):
96-
def __init__(
97-
self,
98-
kernel_size=3,
99-
stride=1,
100-
padding=0,
101-
dilation=1,
102-
):
103-
super().__init__()
104-
self.max_pool2d_module = torch.nn.MaxPool2d(
105-
kernel_size=kernel_size,
106-
stride=stride,
107-
padding=padding,
108-
dilation=dilation,
109-
)
110-
111-
def forward(self, x):
112-
return self.max_pool2d_module(x)
113-
114-
maxpool2d_module = MaxPool2dModule(3, 1, 0, 1)
115-
model_inputs = (torch.randn(4, 3, 24, 24),)
116-
117-
edge_ep = capture_graph_for_xnnpack(maxpool2d_module.eval(), model_inputs)
118-
new_ep = edge_ep.transform(*passes)
119-
result1 = edge_ep(model_inputs[0])[0]
120-
result2 = new_ep(model_inputs[0])[0]
121-
122-
# Filecheck exir_ops.edge.aten.max_pool2d.default node.
123-
FileCheck().check_count(
124-
"executorch_exir_dialects_edge__ops_aten_max_pool2d_default",
125-
1,
126-
exactly=True,
127-
).run(new_ep.exported_program.graph_module.code)
128-
129-
self.assertTrue(torch.allclose(result1, result2))
130-
131-
def test_max_remove_getitem(self) -> None:
132-
passes = [RemoveGetItemPass()]
133-
134-
class MaxModule(torch.nn.Module):
135-
def __init__(
136-
self,
137-
):
138-
super().__init__()
139-
140-
def forward(self, x):
141-
max_vals, indices = torch.max(x, dim=2, keepdim=True)
142-
return max_vals
143-
144-
max_module = MaxModule()
145-
model_inputs = (torch.randn(4, 3, 24, 24),)
146-
147-
edge_ep = capture_graph_for_xnnpack(max_module.eval(), model_inputs)
148-
149-
new_ep = edge_ep.transform(*passes)
150-
result1 = edge_ep(model_inputs[0])[0]
151-
result2 = new_ep(model_inputs[0])[0]
152-
153-
# Filecheck exir_ops.edge.aten.amax.default node.
154-
FileCheck().check_count(
155-
"executorch_exir_dialects_edge__ops_aten_amax_default", 1, exactly=True
156-
).run(new_ep.exported_program.graph_module.code)
157-
158-
self.assertTrue(torch.allclose(result1, result2))
15991

16092
# TODO T154127848: Move this out of XNNPACK dir and into cannonical_partitioner dir
16193
def test_duplicate_dequant_node_pass(self) -> None:

0 commit comments

Comments
 (0)