Skip to content

Commit 9543c24

Browse files
authored
Add from_onnx_text function to convert ONNX text to IR model (#2291)
Simplify conversion of the onnx textual representation to IR. #2290
1 parent 99cdedd commit 9543c24

File tree

7 files changed

+28
-38
lines changed

7 files changed

+28
-38
lines changed

docs/ir/ir_api/core.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ir.load
1717
ir.save
1818
ir.from_proto
19+
ir.from_onnx_text
1920
ir.to_proto
2021
ir.tensor
2122
ir.node

onnxscript/ir/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
"TensorProtoTensor",
7272
# Conversion functions
7373
"from_proto",
74+
"from_onnx_text",
7475
"to_proto",
7576
# Convenience constructors
7677
"tensor",
@@ -144,7 +145,7 @@
144145
TypeProtocol,
145146
ValueProtocol,
146147
)
147-
from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto
148+
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
148149

149150

150151
def __set_module() -> None:

onnxscript/ir/passes/common/inliner_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Callable, Sequence
99

1010
import onnx
11-
from onnx import parser
1211

1312
from onnxscript import ir
1413
from onnxscript.ir.passes.common import inliner
@@ -44,14 +43,12 @@ def _check(
4443
self, input_model: str, expected_model: str, renameable: Sequence[str] | None = None
4544
) -> None:
4645
name_check = _name_checker(renameable)
47-
model_proto = parser.parse_model(input_model)
48-
model_ir = ir.serde.deserialize_model(model_proto)
46+
model_ir = ir.from_onnx_text(input_model)
4947
inliner.InlinePass()(model_ir)
5048
proto = ir.serde.serialize_model(model_ir)
5149
text = onnx.printer.to_text(proto)
5250
print(text)
53-
expected_proto = parser.parse_model(expected_model)
54-
expected_ir = ir.serde.deserialize_model(expected_proto)
51+
expected_ir = ir.from_onnx_text(expected_model)
5552
self.assertEqual(len(model_ir.graph), len(expected_ir.graph))
5653
for node, expected_node in zip(model_ir.graph, expected_ir.graph):
5754
# TODO: handle node renaming

onnxscript/ir/serde.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"TensorProtoTensor",
2222
# Deserialization
2323
"from_proto",
24+
"from_onnx_text",
2425
"deserialize_attribute",
2526
"deserialize_dimension",
2627
"deserialize_function",
@@ -190,6 +191,15 @@ def from_proto(proto: object) -> object:
190191
)
191192

192193

194+
def from_onnx_text(model_text: str, /) -> _core.Model:
195+
"""Convert the ONNX textual representation to an IR model.
196+
197+
Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
198+
"""
199+
proto = onnx.parser.parse_model(model_text)
200+
return deserialize_model(proto)
201+
202+
193203
@typing.overload
194204
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
195205
@typing.overload

onnxscript/optimizer/_constant_folding_test.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,10 @@
1313
from onnxscript.optimizer import _constant_folding
1414

1515

16-
def _create_model(model_text: str) -> ir.Model:
17-
"""Create a model from the given text."""
18-
model = onnx.parser.parse_model(model_text)
19-
return ir.serde.deserialize_model(model)
20-
21-
2216
class FoldConstantsTest(unittest.TestCase):
2317
def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
2418
if isinstance(model, str):
25-
model = _create_model(model)
19+
model = ir.from_onnx_text(model)
2620
_constant_folding.fold_constants(
2721
model, onnx_shape_inference=onnx_shape_inference, **kwargs
2822
)
@@ -552,7 +546,7 @@ def test_large_transpose(self):
552546
z = MatMul (x, wt)
553547
}
554548
"""
555-
model = _create_model(model_text)
549+
model = ir.from_onnx_text(model_text)
556550
w = model.graph.initializers["w"]
557551
w.shape = ir.Shape([512, 256])
558552
w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))

onnxscript/rewriter/no_op_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Licensed under the MIT License.
33
import unittest
44

5-
import onnx.parser
65
import parameterized
76

87
from onnxscript import ir
@@ -11,8 +10,7 @@
1110

1211
class NoOpTest(unittest.TestCase):
1312
def _check(self, model_text: str) -> None:
14-
model_proto = onnx.parser.parse_model(model_text)
15-
model = ir.serde.deserialize_model(model_proto)
13+
model = ir.from_onnx_text(model_text)
1614
count = no_op.rules.apply_to_model(model)
1715
self.assertEqual(count, 1)
1816
self.assertEqual(model.graph[-1].op_type, "Identity")

onnxscript/version_converter/_version_converter_test.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import unittest
66

77
import onnx.defs
8-
import onnx.parser
98

109
from onnxscript import ir, version_converter
1110

@@ -43,7 +42,7 @@ def test_upstream_coverage(self):
4342
self.assertIn((name, upgrade_version), op_upgrades)
4443

4544
def test_version_convert_non_standard_onnx_domain(self):
46-
model_proto = onnx.parser.parse_model(
45+
model = ir.from_onnx_text(
4746
"""
4847
<ir_version: 7, opset_import: [ "local" : 1]>
4948
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -58,7 +57,6 @@ def test_version_convert_non_standard_onnx_domain(self):
5857
}
5958
"""
6059
)
61-
model = ir.serde.deserialize_model(model_proto)
6260
self.assertEqual(model.graph.node(4).op_type, "GridSample")
6361
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")
6462

@@ -76,7 +74,7 @@ def test_version_convert_non_standard_onnx_domain(self):
7674

7775
class VersionConverter18to17Test(unittest.TestCase):
7876
def test_version_convert_compatible(self):
79-
model_proto = onnx.parser.parse_model(
77+
model = ir.from_onnx_text(
8078
"""
8179
<ir_version: 7, opset_import: [ "" : 18]>
8280
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -91,14 +89,13 @@ def test_version_convert_compatible(self):
9189
}
9290
"""
9391
)
94-
model = ir.serde.deserialize_model(model_proto)
9592
target_version = 17
9693
version_converter.convert_version(model, target_version=target_version)
9794

9895

9996
class VersionConverter18to19Test(unittest.TestCase):
10097
def test_version_convert_compatible(self):
101-
model_proto = onnx.parser.parse_model(
98+
model = ir.from_onnx_text(
10299
"""
103100
<ir_version: 7, opset_import: [ "" : 18]>
104101
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -113,7 +110,6 @@ def test_version_convert_compatible(self):
113110
}
114111
"""
115112
)
116-
model = ir.serde.deserialize_model(model_proto)
117113
target_version = 19
118114
version_converter.convert_version(model, target_version=target_version)
119115

@@ -127,7 +123,7 @@ def test_version_convert_compatible(self):
127123

128124
class VersionConverter19to20Test(unittest.TestCase):
129125
def test_version_convert_compatible(self):
130-
model_proto = onnx.parser.parse_model(
126+
model = ir.from_onnx_text(
131127
"""
132128
<ir_version: 7, opset_import: [ "" : 18]>
133129
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
@@ -140,7 +136,6 @@ def test_version_convert_compatible(self):
140136
}
141137
"""
142138
)
143-
model = ir.serde.deserialize_model(model_proto)
144139
target_version = 20
145140
version_converter.convert_version(model, target_version=target_version)
146141

@@ -155,7 +150,7 @@ def test_version_convert_compatible(self):
155150
self.assertEqual(len(model.graph.node(3).inputs), 2)
156151

157152
def test_version_convert_gridsample_linear(self):
158-
model_proto = onnx.parser.parse_model(
153+
model = ir.from_onnx_text(
159154
"""
160155
<ir_version: 7, opset_import: [ "" : 18]>
161156
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -170,7 +165,6 @@ def test_version_convert_gridsample_linear(self):
170165
}
171166
"""
172167
)
173-
model = ir.serde.deserialize_model(model_proto)
174168
self.assertEqual(model.graph.node(4).op_type, "GridSample")
175169
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")
176170

@@ -186,7 +180,7 @@ def test_version_convert_gridsample_linear(self):
186180
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
187181

188182
def test_version_convert_gridsample_cubic(self):
189-
model_proto = onnx.parser.parse_model(
183+
model = ir.from_onnx_text(
190184
"""
191185
<ir_version: 7, opset_import: [ "" : 18]>
192186
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
@@ -201,7 +195,6 @@ def test_version_convert_gridsample_cubic(self):
201195
}
202196
"""
203197
)
204-
model = ir.serde.deserialize_model(model_proto)
205198
self.assertEqual(model.graph.node(4).op_type, "GridSample")
206199
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic")
207200

@@ -217,7 +210,7 @@ def test_version_convert_gridsample_cubic(self):
217210
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")
218211

219212
def test_version_convert_inline(self):
220-
model_proto = onnx.parser.parse_model(
213+
model = ir.from_onnx_text(
221214
"""
222215
<ir_version: 8, opset_import: [ "" : 18]>
223216
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
@@ -236,7 +229,6 @@ def test_version_convert_inline(self):
236229
}
237230
"""
238231
)
239-
model = ir.serde.deserialize_model(model_proto)
240232
target_version = 20
241233
version_converter.convert_version(model, target_version=target_version)
242234

@@ -254,7 +246,7 @@ def test_version_convert_inline(self):
254246

255247
class VersionConverter20to21Test(unittest.TestCase):
256248
def test_version_groupnorm(self):
257-
model_proto = onnx.parser.parse_model(
249+
model = ir.from_onnx_text(
258250
"""
259251
<ir_version: 7, opset_import: [ "" : 18]>
260252
agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output)
@@ -265,7 +257,6 @@ def test_version_groupnorm(self):
265257
}
266258
"""
267259
)
268-
model = ir.serde.deserialize_model(model_proto)
269260
target_version = 21
270261
version_converter.convert_version(model, target_version=target_version)
271262

@@ -285,7 +276,7 @@ def test_version_groupnorm(self):
285276
self.assertEqual(model.graph.node(9).version, 21)
286277

287278
def test_version_groupnorm_no_bias(self):
288-
model_proto = onnx.parser.parse_model(
279+
model = ir.from_onnx_text(
289280
"""
290281
<ir_version: 7, opset_import: [ "" : 18]>
291282
agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output)
@@ -296,7 +287,6 @@ def test_version_groupnorm_no_bias(self):
296287
}
297288
"""
298289
)
299-
model = ir.serde.deserialize_model(model_proto)
300290
target_version = 21
301291
version_converter.convert_version(model, target_version=target_version)
302292

@@ -306,7 +296,7 @@ def test_version_groupnorm_no_bias(self):
306296

307297
class VersionConverter23to24Test(unittest.TestCase):
308298
def test_version_convert_compatible(self):
309-
model_proto = onnx.parser.parse_model(
299+
model = ir.from_onnx_text(
310300
"""
311301
<ir_version: 7, opset_import: [ "" : 23]>
312302
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
@@ -321,7 +311,6 @@ def test_version_convert_compatible(self):
321311
}
322312
"""
323313
)
324-
model = ir.serde.deserialize_model(model_proto)
325314
target_version = 24
326315
version_converter.convert_version(model, target_version=target_version)
327316

0 commit comments

Comments
 (0)