Skip to content

Add from_onnx_text function to convert ONNX text to IR model #2291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/ir/ir_api/core.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ir.load
ir.save
ir.from_proto
ir.from_onnx_text
ir.to_proto
ir.tensor
ir.node
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"TensorProtoTensor",
# Conversion functions
"from_proto",
"from_onnx_text",
"to_proto",
# Convenience constructors
"tensor",
Expand Down Expand Up @@ -144,7 +145,7 @@
TypeProtocol,
ValueProtocol,
)
from onnxscript.ir.serde import TensorProtoTensor, from_proto, to_proto
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto


def __set_module() -> None:
Expand Down
7 changes: 2 additions & 5 deletions onnxscript/ir/passes/common/inliner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Callable, Sequence

import onnx
from onnx import parser

from onnxscript import ir
from onnxscript.ir.passes.common import inliner
Expand Down Expand Up @@ -44,14 +43,12 @@ def _check(
self, input_model: str, expected_model: str, renameable: Sequence[str] | None = None
) -> None:
name_check = _name_checker(renameable)
model_proto = parser.parse_model(input_model)
model_ir = ir.serde.deserialize_model(model_proto)
model_ir = ir.from_onnx_text(input_model)
inliner.InlinePass()(model_ir)
proto = ir.serde.serialize_model(model_ir)
text = onnx.printer.to_text(proto)
print(text)
expected_proto = parser.parse_model(expected_model)
expected_ir = ir.serde.deserialize_model(expected_proto)
expected_ir = ir.from_onnx_text(expected_model)
self.assertEqual(len(model_ir.graph), len(expected_ir.graph))
for node, expected_node in zip(model_ir.graph, expected_ir.graph):
# TODO: handle node renaming
Expand Down
10 changes: 10 additions & 0 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"TensorProtoTensor",
# Deserialization
"from_proto",
"from_onnx_text",
"deserialize_attribute",
"deserialize_dimension",
"deserialize_function",
Expand Down Expand Up @@ -190,6 +191,15 @@ def from_proto(proto: object) -> object:
)


def from_onnx_text(model_text: str, /) -> _core.Model:
"""Convert the ONNX textual representation to an IR model.

Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
"""
proto = onnx.parser.parse_model(model_text)
return deserialize_model(proto)


@typing.overload
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
@typing.overload
Expand Down
10 changes: 2 additions & 8 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,10 @@
from onnxscript.optimizer import _constant_folding


def _create_model(model_text: str) -> ir.Model:
"""Create a model from the given text."""
model = onnx.parser.parse_model(model_text)
return ir.serde.deserialize_model(model)


class FoldConstantsTest(unittest.TestCase):
def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
if isinstance(model, str):
model = _create_model(model)
model = ir.from_onnx_text(model)
_constant_folding.fold_constants(
model, onnx_shape_inference=onnx_shape_inference, **kwargs
)
Expand Down Expand Up @@ -552,7 +546,7 @@ def test_large_transpose(self):
z = MatMul (x, wt)
}
"""
model = _create_model(model_text)
model = ir.from_onnx_text(model_text)
w = model.graph.initializers["w"]
w.shape = ir.Shape([512, 256])
w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
Expand Down
4 changes: 1 addition & 3 deletions onnxscript/rewriter/no_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.
import unittest

import onnx.parser
import parameterized

from onnxscript import ir
Expand All @@ -11,8 +10,7 @@

class NoOpTest(unittest.TestCase):
def _check(self, model_text: str) -> None:
model_proto = onnx.parser.parse_model(model_text)
model = ir.serde.deserialize_model(model_proto)
model = ir.from_onnx_text(model_text)
count = no_op.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(model.graph[-1].op_type, "Identity")
Expand Down
31 changes: 10 additions & 21 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import unittest

import onnx.defs
import onnx.parser

from onnxscript import ir, version_converter

Expand Down Expand Up @@ -43,7 +42,7 @@ def test_upstream_coverage(self):
self.assertIn((name, upgrade_version), op_upgrades)

def test_version_convert_non_standard_onnx_domain(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "local" : 1]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -58,7 +57,6 @@ def test_version_convert_non_standard_onnx_domain(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")

Expand All @@ -76,7 +74,7 @@ def test_version_convert_non_standard_onnx_domain(self):

class VersionConverter18to17Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -91,14 +89,13 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 17
version_converter.convert_version(model, target_version=target_version)


class VersionConverter18to19Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -113,7 +110,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 19
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -127,7 +123,7 @@ def test_version_convert_compatible(self):

class VersionConverter19to20Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x) => (float[4, 257, 64, 2] output)
Expand All @@ -140,7 +136,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 20
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -155,7 +150,7 @@ def test_version_convert_compatible(self):
self.assertEqual(len(model.graph.node(3).inputs), 2)

def test_version_convert_gridsample_linear(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -170,7 +165,6 @@ def test_version_convert_gridsample_linear(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bilinear")

Expand All @@ -186,7 +180,7 @@ def test_version_convert_gridsample_linear(self):
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")

def test_version_convert_gridsample_cubic(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output)
Expand All @@ -201,7 +195,6 @@ def test_version_convert_gridsample_cubic(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
self.assertEqual(model.graph.node(4).op_type, "GridSample")
self.assertEqual(model.graph.node(4).attributes["mode"].value, "bicubic")

Expand All @@ -217,7 +210,7 @@ def test_version_convert_gridsample_cubic(self):
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")

def test_version_convert_inline(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 8, opset_import: [ "" : 18]>
agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 257, 64, 2] output)
Expand All @@ -236,7 +229,6 @@ def test_version_convert_inline(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 20
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -254,7 +246,7 @@ def test_version_convert_inline(self):

class VersionConverter20to21Test(unittest.TestCase):
def test_version_groupnorm(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[2] scale, float[2] bias) => (float[4, 512, 512] output)
Expand All @@ -265,7 +257,6 @@ def test_version_groupnorm(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 21
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -285,7 +276,7 @@ def test_version_groupnorm(self):
self.assertEqual(model.graph.node(9).version, 21)

def test_version_groupnorm_no_bias(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[1, 4, 512, 512] input_x, float[2] scale) => (float[4, 512, 512] output)
Expand All @@ -296,7 +287,6 @@ def test_version_groupnorm_no_bias(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 21
version_converter.convert_version(model, target_version=target_version)

Expand All @@ -306,7 +296,7 @@ def test_version_groupnorm_no_bias(self):

class VersionConverter23to24Test(unittest.TestCase):
def test_version_convert_compatible(self):
model_proto = onnx.parser.parse_model(
model = ir.from_onnx_text(
"""
<ir_version: 7, opset_import: [ "" : 23]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
Expand All @@ -321,7 +311,6 @@ def test_version_convert_compatible(self):
}
"""
)
model = ir.serde.deserialize_model(model_proto)
target_version = 24
version_converter.convert_version(model, target_version=target_version)

Expand Down
Loading