Skip to content

Commit dcd25eb

Browse files
Arm backend: Refactor pass tests for TOSA V1.0 (#10843)
PassPipelines will handle tosa_version aligned with other test pipelines. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 7993bb2 commit dcd25eb

21 files changed

+206
-198
lines changed

backends/arm/test/passes/test_cast_int64_pass.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,38 @@
66
from typing import Tuple
77

88
import torch
9-
from executorch.backends.arm._passes.cast_int64_pass import CastInt64BuffersToInt32Pass
9+
from executorch.backends.arm._passes import CastInt64BuffersToInt32Pass
1010

11+
from executorch.backends.arm.test import common
1112
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1213

1314
input_t = Tuple[torch.Tensor] # Input x
1415

1516

1617
class Int64Model(torch.nn.Module):
18+
test_data = {
19+
"rand": (torch.rand(4),),
20+
}
1721

1822
def forward(self, x: torch.Tensor):
1923
return x + 3
2024

21-
def get_inputs(self) -> input_t:
22-
return (torch.rand(4),)
23-
2425

25-
def test_int64_model_tosa_BI():
26+
@common.parametrize("test_data", Int64Model.test_data)
27+
def test_int64_model(test_data: input_t):
2628
module = Int64Model()
2729
op_checks = {
2830
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
2931
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
3032
}
3133
pipeline = PassPipeline[input_t](
3234
module,
33-
module.get_inputs(),
34-
tosa_version="TOSA-0.80+BI",
35+
test_data,
36+
quantize=False,
3537
ops_before_pass=op_checks,
3638
ops_after_pass=op_checks,
3739
passes_with_exported_program=[CastInt64BuffersToInt32Pass],
3840
)
39-
pipeline.pop_stage("quantize")
4041
pipeline.run()
4142

4243
exported_program = pipeline.tester.get_artifact("RunPasses").exported_program()

backends/arm/test/passes/test_convert_expand_copy_to_repeat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_expand_to_repeat_tosa_BI():
3535
pipeline = PassPipeline[input_t](
3636
module,
3737
module.get_inputs(),
38-
tosa_version="TOSA-0.80+BI",
38+
quantize=True,
3939
ops_before_pass={
4040
"executorch_exir_dialects_edge__ops_aten_expand_copy_default": 1,
4141
},

backends/arm/test/passes/test_convert_split_to_slice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_split_to_slice_tosa_BI(module):
4949
pipeline = PassPipeline[input_t](
5050
module,
5151
module.get_inputs(),
52-
tosa_version="TOSA-0.80+BI",
52+
quantize=True,
5353
ops_before_pass={
5454
"executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1,
5555
},

backends/arm/test/passes/test_convert_to_clamp.py

+60-48
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import unittest
6+
7+
from typing import Tuple
78

89
import torch
910
from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass
1011

1112
from executorch.backends.arm.test import common
12-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1314

14-
from executorch.backends.xnnpack.test.tester.tester import RunPasses
15+
input_t = Tuple[torch.Tensor] # Input x
1516

1617

1718
class HardTanh(torch.nn.Module):
19+
test_data = {"rand": (torch.rand(1, 64, 64, 3),)}
20+
1821
def __init__(self):
1922
super().__init__()
2023

@@ -23,11 +26,10 @@ def __init__(self):
2326
def forward(self, x):
2427
return self.hardtanh(x)
2528

26-
def get_inputs(self):
27-
return (torch.rand(1, 64, 64, 3),)
28-
2929

3030
class ReLU(torch.nn.Module):
31+
test_data = {"rand": (torch.rand(1, 64, 64, 3),)}
32+
3133
def __init__(self):
3234
super().__init__()
3335

@@ -36,45 +38,55 @@ def __init__(self):
3638
def forward(self, x):
3739
return self.relu(x)
3840

39-
def get_inputs(self):
40-
return (torch.rand(1, 64, 64, 3),)
41-
42-
43-
class TestConvertToClampPass(unittest.TestCase):
44-
"""
45-
Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
46-
"""
47-
48-
def test_tosa_MI_hardtahn(self):
49-
module = HardTanh()
50-
test_pass_stage = RunPasses([ConvertToClampPass])
51-
(
52-
ArmTester(
53-
module,
54-
example_inputs=module.get_inputs(),
55-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
56-
)
57-
.export()
58-
.to_edge()
59-
.check(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
60-
.run_passes(test_pass_stage)
61-
.check(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
62-
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
63-
)
64-
65-
def test_tosa_MI_relu(self):
66-
module = ReLU()
67-
test_pass_stage = RunPasses([ConvertToClampPass])
68-
(
69-
ArmTester(
70-
module,
71-
example_inputs=module.get_inputs(),
72-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
73-
)
74-
.export()
75-
.to_edge()
76-
.check(["executorch_exir_dialects_edge__ops_aten_relu_default"])
77-
.run_passes(test_pass_stage)
78-
.check(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
79-
.check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
80-
)
41+
42+
"""
43+
Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
44+
"""
45+
46+
47+
@common.parametrize("test_data", HardTanh.test_data)
48+
def test_tosa_MI_hardtahn(test_data: input_t):
49+
module = HardTanh()
50+
op_checks_before_pass = {
51+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1,
52+
}
53+
op_checks_after_pass = {
54+
"executorch_exir_dialects_edge__ops_aten_clamp_default": 1,
55+
}
56+
op_checks_not_after_pass = [
57+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
58+
]
59+
pipeline = PassPipeline[input_t](
60+
module,
61+
test_data,
62+
quantize=False,
63+
ops_before_pass=op_checks_before_pass,
64+
ops_after_pass=op_checks_after_pass,
65+
ops_not_after_pass=op_checks_not_after_pass,
66+
pass_list=[ConvertToClampPass],
67+
)
68+
pipeline.run()
69+
70+
71+
@common.parametrize("test_data", ReLU.test_data)
72+
def test_tosa_MI_relu(test_data: input_t):
73+
module = ReLU()
74+
op_checks_before_pass = {
75+
"executorch_exir_dialects_edge__ops_aten_relu_default": 1,
76+
}
77+
op_checks_after_pass = {
78+
"executorch_exir_dialects_edge__ops_aten_clamp_default": 1,
79+
}
80+
op_checks_not_after_pass = [
81+
"executorch_exir_dialects_edge__ops_aten_relu_default",
82+
]
83+
pipeline = PassPipeline[input_t](
84+
module,
85+
test_data,
86+
quantize=False,
87+
ops_before_pass=op_checks_before_pass,
88+
ops_after_pass=op_checks_after_pass,
89+
ops_not_after_pass=op_checks_not_after_pass,
90+
pass_list=[ConvertToClampPass],
91+
)
92+
pipeline.run()

backends/arm/test/passes/test_decompose_cosine_similarity_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ def test_decompose_cosine_similarity_tosa_BI(module):
4242
pipeline = PassPipeline[input_t](
4343
module,
4444
module.get_inputs(),
45-
tosa_version="TOSA-0.80+BI",
4645
ops_before_pass=None,
4746
ops_not_before_pass=None,
4847
ops_after_pass=ops_after_pass,
4948
ops_not_after_pass=None,
5049
pass_list=[DecomposeCosineSimilarityPass],
50+
quantize=True,
5151
)
5252
pipeline.run()

backends/arm/test/passes/test_decompose_div_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_decompose_div_tosa_MI(module):
4747
pipeline = PassPipeline[input_t](
4848
module,
4949
module.get_inputs(),
50-
tosa_version="TOSA-0.80+MI",
50+
quantize=False,
5151
ops_before_pass={
5252
"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1,
5353
},

backends/arm/test/passes/test_decompose_layernorm_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_decompose_layernorm_tosa_MI():
3737
pipeline = PassPipeline[input_t](
3838
module,
3939
module.get_inputs(),
40-
tosa_version="TOSA-0.80+MI",
40+
quantize=False,
4141
ops_before_pass={
4242
"executorch_exir_dialects_edge__ops_aten_native_layer_norm_default": 1,
4343
},

backends/arm/test/passes/test_decompose_meandim_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_decompose_meandim_tosa_MI(module):
5353
pipeline = PassPipeline[input_t](
5454
module,
5555
module.get_inputs(),
56-
tosa_version="TOSA-0.80+MI",
56+
quantize=False,
5757
ops_before_pass={
5858
"executorch_exir_dialects_edge__ops_aten_mean_dim": 1,
5959
},

backends/arm/test/passes/test_decompose_softmax_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_softmax_basic_tosa_MI():
5252
pipeline = PassPipeline[input_t](
5353
module,
5454
module.get_inputs(),
55-
tosa_version="TOSA-0.80+MI",
55+
quantize=False,
5656
ops_before_pass={
5757
"executorch_exir_dialects_edge__ops_aten__softmax_default": 1,
5858
},
@@ -79,7 +79,7 @@ def test_softmax_log_tosa_MI():
7979
pipeline = PassPipeline[input_t](
8080
module,
8181
module.get_inputs(),
82-
tosa_version="TOSA-0.80+MI",
82+
quantize=False,
8383
ops_before_pass={
8484
"executorch_exir_dialects_edge__ops_aten__log_softmax_default": 1,
8585
},

backends/arm/test/passes/test_decompose_var_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_decompose_var_tosa_MI(module):
6060
pipeline = PassPipeline[input_t](
6161
module,
6262
module.get_inputs(),
63-
tosa_version="TOSA-0.80+MI",
63+
quantize=False,
6464
ops_before_pass={
6565
"executorch_exir_dialects_edge__ops_aten_var_correction": 1,
6666
},

backends/arm/test/passes/test_fold_qdq_pass.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,24 @@
77

88
import torch
99
from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass
10+
from executorch.backends.arm.test import common
1011
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
1112

1213

1314
input_t = Tuple[torch.Tensor, torch.Tensor] # Input x, y
1415

1516

1617
class SimpleQuantizeModel(torch.nn.Module):
18+
test_data = {
19+
"rand": (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)),
20+
}
21+
1722
def forward(self, x, y):
1823
return x + torch.max((x + x), (y + y))
1924

20-
def get_inputs(self) -> input_t:
21-
return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7))
22-
2325

24-
def test_fold_qdq_pass_tosa_BI():
26+
@common.parametrize("test_data", SimpleQuantizeModel.test_data)
27+
def test_fold_qdq_pass_tosa_BI(test_data: input_t):
2528
"""
2629
Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into
2730
the node and stores the quantization parameters in meta.
@@ -32,8 +35,8 @@ def test_fold_qdq_pass_tosa_BI():
3235
module = SimpleQuantizeModel()
3336
pipeline = PassPipeline[input_t](
3437
module,
35-
module.get_inputs(),
36-
tosa_version="TOSA-0.80+BI",
38+
test_data,
39+
quantize=True,
3740
ops_before_pass={
3841
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7,
3942
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6,

backends/arm/test/passes/test_fuse_batchnorm_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,12 @@ def forward(self, x):
136136

137137

138138
@common.parametrize("module", modules)
139-
def test_fuse_batchnorm_tosa_MI(module):
139+
def test_fuse_batchnorm_tosa_MI(module: torch.nn.Module):
140140
"""Test various cases where the batchnorm should and shouldn't be fused."""
141141
pipeline = PassPipeline[input_t](
142142
module,
143143
module.get_inputs(),
144-
tosa_version="TOSA-0.80+MI",
144+
quantize=False,
145145
ops_before_pass=module.ops_before_pass,
146146
ops_after_pass=module.ops_after_pass,
147147
passes_with_exported_program=[FuseBatchnorm2DPass],

backends/arm/test/passes/test_fuse_constant_ops_pass.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9898

9999

100100
@common.parametrize("module", modules)
101-
def test_fuse_const_ops_tosa_MI(module):
101+
def test_fuse_const_ops_tosa_MI(module: torch.nn.Module):
102102
pipeline = PassPipeline[input_t](
103103
module=module,
104104
test_data=(torch.rand(1),),
105-
tosa_version="TOSA-0.80+MI",
105+
quantize=False,
106106
ops_before_pass=module.ops_before_pass,
107107
ops_after_pass=module.ops_after_pass,
108108
ops_not_after_pass=module.ops_not_after_pass,
@@ -113,8 +113,13 @@ def test_fuse_const_ops_tosa_MI(module):
113113

114114
@unittest.skip("Test failing on internal CI")
115115
@common.parametrize("module", modules)
116-
def test_fuse_const_ops_tosa_BI(module):
116+
def test_fuse_const_ops_tosa_BI(module: torch.nn.Module):
117117
pipeline = TosaPipelineBI[input_t](
118-
module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True
118+
module,
119+
(torch.rand(10, 10),),
120+
[],
121+
[],
122+
quantize=True,
123+
use_to_edge_transform_and_lower=True,
119124
)
120125
pipeline.run()

backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_fuse_equal_placeholders_constants_tosa_MI():
6060
pipeline = PassPipeline[input_t](
6161
module,
6262
data,
63-
tosa_version="TOSA-0.80+MI",
63+
quantize=False,
6464
ops_before_pass=module.ops_before_pass,
6565
ops_after_pass=module.ops_after_pass,
6666
passes_with_exported_program=[FuseEqualPlaceholdersPass],
@@ -81,7 +81,7 @@ def test_fuse_equal_placeholders_state_dict_tosa_MI():
8181
pipeline = PassPipeline[input_t](
8282
module,
8383
data,
84-
tosa_version="TOSA-0.80+MI",
84+
quantize=False,
8585
ops_before_pass=module.ops_before_pass,
8686
ops_after_pass=module.ops_after_pass,
8787
passes_with_exported_program=[FuseEqualPlaceholdersPass],

0 commit comments

Comments
 (0)