Skip to content

Commit d2fab20

Browse files
authored
Remove legacy_ir usage in testutil (#2451)
Also clean up the function rewrite rules. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 727210b commit d2fab20

File tree

10 files changed

+6
-856
lines changed

10 files changed

+6
-856
lines changed

onnxscript/rewriter/onnxruntime/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from __future__ import annotations
77

8-
from typing import Any, Sequence
8+
from typing import Sequence
99

1010
import onnx
1111

@@ -16,11 +16,8 @@
1616
__all__ = [
1717
"rewrite",
1818
"ORT_PATTERN_REWRITE_RULES",
19-
"ORT_FUNCTION_REWRITE_RULES",
2019
]
2120

22-
ORT_FUNCTION_REWRITE_RULES: list[Any] = []
23-
2421

2522
def rewrite(
2623
model_proto: onnx.ModelProto,

tests/common/testutils.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
import numpy as np
1111
import onnx
12+
import onnx_ir as ir
1213
import onnxruntime
1314
import torch
1415

1516
from onnxscript import optimizer
16-
from onnxscript._legacy_ir import visitor
1717
from onnxscript.rewriter import onnxruntime as ort_rewriter
1818
from onnxscript.utils import evaluation_utils
1919

@@ -39,20 +39,6 @@ def wrapper(self, *args, **kwargs):
3939
return skip_dec
4040

4141

42-
class OpTypeAnalysisVisitor(visitor.ProtoVisitorCore):
43-
def __init__(self):
44-
super().__init__()
45-
self.op_types = set()
46-
47-
def visit_model(self, model: onnx.ModelProto):
48-
self.op_types = set()
49-
super().visit_model(model)
50-
51-
def process_node(self, node: onnx.NodeProto):
52-
self.op_types.add((node.domain, node.op_type, getattr(node, "overload", "")))
53-
return super().process_node(node)
54-
55-
5642
def test_onnxruntime_rewrite(
5743
model_basename: str,
5844
model_count: int,
@@ -84,10 +70,11 @@ def test_onnxruntime_rewrite(
8470
# onnx.save(rewritten, model_dir / f"{model_name}_opt.onnx")
8571

8672
# Check expected operator is found.
87-
optype_analysis = OpTypeAnalysisVisitor()
88-
optype_analysis.visit_model(rewritten)
73+
op_types = set()
74+
for node in ir.from_proto(model).graph.all_nodes():
75+
op_types.add((node.domain, node.op_type, node.overload))
8976
for domain, op_type, overload in expected_optypes:
90-
if (domain, op_type, overload) not in optype_analysis.op_types:
77+
if (domain, op_type, overload) not in op_types:
9178
raise AssertionError(
9279
f"Expected op type {domain}:{op_type}:{overload} not found in rewritten model."
9380
)

tools/diagnostics/gen_diagnostics.py

Lines changed: 0 additions & 257 deletions
This file was deleted.

tools/diagnostics/gen_diagnostics.sh

Lines changed: 0 additions & 16 deletions
This file was deleted.

tools/diagnostics/sarif/code-gen-hints.json

Lines changed: 0 additions & 10 deletions
This file was deleted.

0 commit comments

Comments
 (0)