Skip to content

Commit 99cdedd

Browse files
authored
[IR] Normalize "ai.onnx" domain to "" (#2283)
TODO: Also update opset_imports to handle the "ai.onnx" key. Fix #2280
1 parent 0e09e58 commit 99cdedd

File tree

4 files changed

+22
-5
lines changed

4 files changed

+22
-5
lines changed

onnxscript/ir/_core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,11 @@ def _short_tensor_str_for_node(x: Value) -> str:
12781278
return "{...}"
12791279

12801280

1281+
def _normalize_domain(domain: str) -> str:
1282+
"""Normalize 'ai.onnx' to ''"""
1283+
return "" if domain == "ai.onnx" else domain
1284+
1285+
12811286
class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
12821287
"""IR Node.
12831288
@@ -1328,6 +1333,7 @@ def __init__(
13281333
13291334
Args:
13301335
domain: The domain of the operator. For onnx operators, this is an empty string.
1336+
When it is "ai.onnx", it is normalized to "".
13311337
op_type: The name of the operator.
13321338
inputs: The input values. When an input is ``None``, it is an empty input.
13331339
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1350,7 +1356,7 @@ def __init__(
13501356
ValueError: If an output value has a producer set already, when outputs is specified.
13511357
"""
13521358
self._name = name
1353-
self._domain: str = domain
1359+
self._domain: str = _normalize_domain(domain)
13541360
self._op_type: str = op_type
13551361
# NOTE: Make inputs immutable with the assumption that they are not mutated
13561362
# very often. This way all mutations can be tracked.
@@ -1482,7 +1488,7 @@ def domain(self) -> str:
14821488

14831489
@domain.setter
14841490
def domain(self, value: str) -> None:
1485-
self._domain = value
1491+
self._domain = _normalize_domain(value)
14861492

14871493
@property
14881494
def version(self) -> int | None:
@@ -2885,7 +2891,7 @@ def domain(self) -> str:
28852891

28862892
@domain.setter
28872893
def domain(self, value: str) -> None:
2888-
self._domain = value
2894+
self._domain = _normalize_domain(value)
28892895

28902896
@property
28912897
def overload(self) -> str:

onnxscript/ir/_core_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,17 @@ def test_successors(self):
850850
def test_successors_are_unique(self):
851851
self.assertEqual(self.node.successors(), (self.node_a, self.node_b))
852852

853+
def test_domain_normalizes_ai_onnx(self):
854+
# Node domain is always normalized to "" if it is "ai.onnx"
855+
node = _core.Node("ai.onnx", "TestOp", inputs=())
856+
self.assertEqual(node.domain, "")
857+
858+
node.domain = ""
859+
self.assertEqual(node.domain, "")
860+
861+
node.domain = "ai.onnx"
862+
self.assertEqual(node.domain, "")
863+
853864
# TODO(justinchuby): Test all methods
854865

855866

onnxscript/optimizer/_constant_folding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def is_constant_op(node: ir.Node) -> bool:
6464

6565
def _process_constant_node(node: ir.Node) -> None:
6666
"""Sets const_value of output value of a Constant op node."""
67-
if node.op_type != "Constant" or node.domain not in {"", "ai.onnx"}:
67+
if node.op_type != "Constant" or node.domain != "":
6868
return
6969
if len(node.attributes) != 1:
7070
return

onnxscript/version_converter/_version_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: boo
229229
def process_node(
230230
self, node: ir.Node, opset_version: int, up_conversion: bool = True
231231
) -> Replacement | None:
232-
if node.domain not in {"", "ai.onnx"}:
232+
if node.domain != "":
233233
return None
234234
adapter = registry.lookup_adapters(
235235
node.domain, node.op_type, opset_version, up_conversion

0 commit comments

Comments
 (0)