5
5
import unittest
6
6
7
7
import onnx .defs
8
- import onnx .parser
9
8
10
9
from onnxscript import ir , version_converter
11
10
@@ -43,7 +42,7 @@ def test_upstream_coverage(self):
43
42
self .assertIn ((name , upgrade_version ), op_upgrades )
44
43
45
44
def test_version_convert_non_standard_onnx_domain (self ):
46
- model_proto = onnx . parser . parse_model (
45
+ model = ir . from_onnx_text (
47
46
"""
48
47
<ir_version: 7, opset_import: [ "local" : 1]>
49
48
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -58,7 +57,6 @@ def test_version_convert_non_standard_onnx_domain(self):
58
57
}
59
58
"""
60
59
)
61
- model = ir .serde .deserialize_model (model_proto )
62
60
self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
63
61
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bilinear" )
64
62
@@ -76,7 +74,7 @@ def test_version_convert_non_standard_onnx_domain(self):
76
74
77
75
class VersionConverter18to17Test (unittest .TestCase ):
78
76
def test_version_convert_compatible (self ):
79
- model_proto = onnx . parser . parse_model (
77
+ model = ir . from_onnx_text (
80
78
"""
81
79
<ir_version: 7, opset_import: [ "" : 18]>
82
80
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -91,14 +89,13 @@ def test_version_convert_compatible(self):
91
89
}
92
90
"""
93
91
)
94
- model = ir .serde .deserialize_model (model_proto )
95
92
target_version = 17
96
93
version_converter .convert_version (model , target_version = target_version )
97
94
98
95
99
96
class VersionConverter18to19Test (unittest .TestCase ):
100
97
def test_version_convert_compatible (self ):
101
- model_proto = onnx . parser . parse_model (
98
+ model = ir . from_onnx_text (
102
99
"""
103
100
<ir_version: 7, opset_import: [ "" : 18]>
104
101
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -113,7 +110,6 @@ def test_version_convert_compatible(self):
113
110
}
114
111
"""
115
112
)
116
- model = ir .serde .deserialize_model (model_proto )
117
113
target_version = 19
118
114
version_converter .convert_version (model , target_version = target_version )
119
115
@@ -127,7 +123,7 @@ def test_version_convert_compatible(self):
127
123
128
124
class VersionConverter19to20Test (unittest .TestCase ):
129
125
def test_version_convert_compatible (self ):
130
- model_proto = onnx . parser . parse_model (
126
+ model = ir . from_onnx_text (
131
127
"""
132
128
<ir_version: 7, opset_import: [ "" : 18]>
133
129
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
@@ -140,7 +136,6 @@ def test_version_convert_compatible(self):
140
136
}
141
137
"""
142
138
)
143
- model = ir .serde .deserialize_model (model_proto )
144
139
target_version = 20
145
140
version_converter .convert_version (model , target_version = target_version )
146
141
@@ -155,7 +150,7 @@ def test_version_convert_compatible(self):
155
150
self .assertEqual (len (model .graph .node (3 ).inputs ), 2 )
156
151
157
152
def test_version_convert_gridsample_linear (self ):
158
- model_proto = onnx . parser . parse_model (
153
+ model = ir . from_onnx_text (
159
154
"""
160
155
<ir_version: 7, opset_import: [ "" : 18]>
161
156
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -170,7 +165,6 @@ def test_version_convert_gridsample_linear(self):
170
165
}
171
166
"""
172
167
)
173
- model = ir .serde .deserialize_model (model_proto )
174
168
self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
175
169
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bilinear" )
176
170
@@ -186,7 +180,7 @@ def test_version_convert_gridsample_linear(self):
186
180
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "linear" )
187
181
188
182
def test_version_convert_gridsample_cubic (self ):
189
- model_proto = onnx . parser . parse_model (
183
+ model = ir . from_onnx_text (
190
184
"""
191
185
<ir_version: 7, opset_import: [ "" : 18]>
192
186
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -201,7 +195,6 @@ def test_version_convert_gridsample_cubic(self):
201
195
}
202
196
"""
203
197
)
204
- model = ir .serde .deserialize_model (model_proto )
205
198
self .assertEqual (model .graph .node (4 ).op_type , "GridSample" )
206
199
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "bicubic" )
207
200
@@ -217,7 +210,7 @@ def test_version_convert_gridsample_cubic(self):
217
210
self .assertEqual (model .graph .node (4 ).attributes ["mode" ].value , "cubic" )
218
211
219
212
def test_version_convert_inline (self ):
220
- model_proto = onnx . parser . parse_model (
213
+ model = ir . from_onnx_text (
221
214
"""
222
215
<ir_version: 8, opset_import: [ "" : 18]>
223
216
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
@@ -236,7 +229,6 @@ def test_version_convert_inline(self):
236
229
}
237
230
"""
238
231
)
239
- model = ir .serde .deserialize_model (model_proto )
240
232
target_version = 20
241
233
version_converter .convert_version (model , target_version = target_version )
242
234
@@ -254,7 +246,7 @@ def test_version_convert_inline(self):
254
246
255
247
class VersionConverter20to21Test (unittest .TestCase ):
256
248
def test_version_groupnorm (self ):
257
- model_proto = onnx . parser . parse_model (
249
+ model = ir . from_onnx_text (
258
250
"""
259
251
<ir_version: 7, opset_import: [ "" : 18]>
260
252
agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output)
@@ -265,7 +257,6 @@ def test_version_groupnorm(self):
265
257
}
266
258
"""
267
259
)
268
- model = ir .serde .deserialize_model (model_proto )
269
260
target_version = 21
270
261
version_converter .convert_version (model , target_version = target_version )
271
262
@@ -285,7 +276,7 @@ def test_version_groupnorm(self):
285
276
self .assertEqual (model .graph .node (9 ).version , 21 )
286
277
287
278
def test_version_groupnorm_no_bias (self ):
288
- model_proto = onnx . parser . parse_model (
279
+ model = ir . from_onnx_text (
289
280
"""
290
281
<ir_version: 7, opset_import: [ "" : 18]>
291
282
agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output)
@@ -296,7 +287,6 @@ def test_version_groupnorm_no_bias(self):
296
287
}
297
288
"""
298
289
)
299
- model = ir .serde .deserialize_model (model_proto )
300
290
target_version = 21
301
291
version_converter .convert_version (model , target_version = target_version )
302
292
@@ -306,7 +296,7 @@ def test_version_groupnorm_no_bias(self):
306
296
307
297
class VersionConverter23to24Test (unittest .TestCase ):
308
298
def test_version_convert_compatible (self ):
309
- model_proto = onnx . parser . parse_model (
299
+ model = ir . from_onnx_text (
310
300
"""
311
301
<ir_version: 7, opset_import: [ "" : 23]>
312
302
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -321,7 +311,6 @@ def test_version_convert_compatible(self):
321
311
}
322
312
"""
323
313
)
324
- model = ir .serde .deserialize_model (model_proto )
325
314
target_version = 24
326
315
version_converter .convert_version (model , target_version = target_version )
327
316
0 commit comments