3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
- import unittest
6
+
7
+ from typing import Tuple
7
8
8
9
import torch
9
10
from executorch .backends .arm ._passes .convert_to_clamp import ConvertToClampPass
10
11
11
12
from executorch .backends .arm .test import common
12
- from executorch .backends .arm .test .tester .arm_tester import ArmTester
13
+ from executorch .backends .arm .test .tester .test_pipeline import PassPipeline
13
14
14
- from executorch . backends . xnnpack . test . tester . tester import RunPasses
15
+ input_t = Tuple [ torch . Tensor ] # Input x
15
16
16
17
17
18
class HardTanh (torch .nn .Module ):
19
+ test_data = {"rand" : (torch .rand (1 , 64 , 64 , 3 ),)}
20
+
18
21
def __init__ (self ):
19
22
super ().__init__ ()
20
23
@@ -23,11 +26,10 @@ def __init__(self):
23
26
def forward (self , x ):
24
27
return self .hardtanh (x )
25
28
26
- def get_inputs (self ):
27
- return (torch .rand (1 , 64 , 64 , 3 ),)
28
-
29
29
30
30
class ReLU (torch .nn .Module ):
31
+ test_data = {"rand" : (torch .rand (1 , 64 , 64 , 3 ),)}
32
+
31
33
def __init__ (self ):
32
34
super ().__init__ ()
33
35
@@ -36,45 +38,55 @@ def __init__(self):
36
38
def forward (self , x ):
37
39
return self .relu (x )
38
40
39
- def get_inputs (self ):
40
- return (torch .rand (1 , 64 , 64 , 3 ),)
41
-
42
-
43
- class TestConvertToClampPass (unittest .TestCase ):
44
- """
45
- Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
46
- """
47
-
48
- def test_tosa_MI_hardtahn (self ):
49
- module = HardTanh ()
50
- test_pass_stage = RunPasses ([ConvertToClampPass ])
51
- (
52
- ArmTester (
53
- module ,
54
- example_inputs = module .get_inputs (),
55
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
56
- )
57
- .export ()
58
- .to_edge ()
59
- .check (["executorch_exir_dialects_edge__ops_aten_hardtanh_default" ])
60
- .run_passes (test_pass_stage )
61
- .check (["executorch_exir_dialects_edge__ops_aten_clamp_default" ])
62
- .check_not (["executorch_exir_dialects_edge__ops_aten_hardtanh_default" ])
63
- )
64
-
65
- def test_tosa_MI_relu (self ):
66
- module = ReLU ()
67
- test_pass_stage = RunPasses ([ConvertToClampPass ])
68
- (
69
- ArmTester (
70
- module ,
71
- example_inputs = module .get_inputs (),
72
- compile_spec = common .get_tosa_compile_spec ("TOSA-0.80+MI" ),
73
- )
74
- .export ()
75
- .to_edge ()
76
- .check (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
77
- .run_passes (test_pass_stage )
78
- .check (["executorch_exir_dialects_edge__ops_aten_clamp_default" ])
79
- .check_not (["executorch_exir_dialects_edge__ops_aten_relu_default" ])
80
- )
41
+
42
+ """
43
+ Tests the ConvertToClampPass which converts hardtanh.default and relu.default to clamp.default
44
+ """
45
+
46
+
47
+ @common .parametrize ("test_data" , HardTanh .test_data )
48
+ def test_tosa_MI_hardtahn (test_data : input_t ):
49
+ module = HardTanh ()
50
+ op_checks_before_pass = {
51
+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" : 1 ,
52
+ }
53
+ op_checks_after_pass = {
54
+ "executorch_exir_dialects_edge__ops_aten_clamp_default" : 1 ,
55
+ }
56
+ op_checks_not_after_pass = [
57
+ "executorch_exir_dialects_edge__ops_aten_hardtanh_default" ,
58
+ ]
59
+ pipeline = PassPipeline [input_t ](
60
+ module ,
61
+ test_data ,
62
+ quantize = False ,
63
+ ops_before_pass = op_checks_before_pass ,
64
+ ops_after_pass = op_checks_after_pass ,
65
+ ops_not_after_pass = op_checks_not_after_pass ,
66
+ pass_list = [ConvertToClampPass ],
67
+ )
68
+ pipeline .run ()
69
+
70
+
71
+ @common .parametrize ("test_data" , ReLU .test_data )
72
+ def test_tosa_MI_relu (test_data : input_t ):
73
+ module = ReLU ()
74
+ op_checks_before_pass = {
75
+ "executorch_exir_dialects_edge__ops_aten_relu_default" : 1 ,
76
+ }
77
+ op_checks_after_pass = {
78
+ "executorch_exir_dialects_edge__ops_aten_clamp_default" : 1 ,
79
+ }
80
+ op_checks_not_after_pass = [
81
+ "executorch_exir_dialects_edge__ops_aten_relu_default" ,
82
+ ]
83
+ pipeline = PassPipeline [input_t ](
84
+ module ,
85
+ test_data ,
86
+ quantize = False ,
87
+ ops_before_pass = op_checks_before_pass ,
88
+ ops_after_pass = op_checks_after_pass ,
89
+ ops_not_after_pass = op_checks_not_after_pass ,
90
+ pass_list = [ConvertToClampPass ],
91
+ )
92
+ pipeline .run ()
0 commit comments