Skip to content

Commit 1de2a95

Browse files
committed
First step to add forward indexes.
1 parent 36d7413 commit 1de2a95

File tree

4 files changed

+66
-9
lines changed

4 files changed

+66
-9
lines changed

tests/test_optimizers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,6 @@ def test_duplicated_duplicated_attributes(self):
912912
op_type="ReduceSum", remaining_op_num=2)
913913

914914
def _check_initializer_num(self, graph_proto, num):
915-
print(len(graph_proto.initializer))
916915
return num == len(graph_proto.initializer)
917916

918917
def test_duplicated_duplicated_constant(self):

tf2onnx/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
MICROSOFT_DOMAIN = "com.microsoft"
1616

1717
# Default opset version for onnx domain
18-
PREFERRED_OPSET = 8
18+
PREFERRED_OPSET = 11
1919

2020
# Default opset for custom ops
2121
TENSORFLOW_OPSET = helper.make_opsetid("ai.onnx.converters.tensorflow", 1)

tf2onnx/graph.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
418418
self._nodes = []
419419
self._nodes_by_name = {}
420420
self._output_to_node_name = {}
421+
self._input_to_node_name = {}
421422
self.shapes = {}
422423
self.graph_name = graph_name or "tf2onnx"
423424
self._is_subgraph = is_subgraph
@@ -442,7 +443,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
442443

443444
ops = [Node(node, self) for node in nodes]
444445
self.reset_nodes(ops)
445-
446+
446447
if not is_subgraph:
447448
# add identity node after each output, in case it is renamed during conversion.
448449
for o in self.outputs:
@@ -569,6 +570,11 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
569570

570571
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
571572

573+
for name in onnx_node.input:
574+
if name not in self._input_to_node_name:
575+
self._input_to_node_name[name] = set()
576+
self._input_to_node_name[name].add(onnx_node.name)
577+
572578
if op_type in ["If", "Loop", "Scan"]:
573579
# we force the op containing inner graphs not skipped during conversion.
574580
skip_conversion = False
@@ -606,6 +612,10 @@ def append_node(self, node):
606612
self._output_to_node_name[name] = node.name
607613
self.set_dtype(name, output_dtypes[i])
608614
self.set_shape(name, output_shapes[i])
615+
for name in node.input:
616+
if name not in self._input_to_node_name:
617+
self._input_to_node_name[name] = set()
618+
self._input_to_node_name[name].add(node.name)
609619

610620
def remove_node(self, node_name):
611621
"""Remove node in current graph."""
@@ -626,6 +636,13 @@ def remove_node(self, node_name):
626636
if op_output in self._dtypes:
627637
del self._dtypes[op_output]
628638

639+
for op_input in node.input:
640+
if op_input not in self._input_to_node_name:
641+
raise RuntimeError(
642+
"Input %r of node %r not found." % (op_input, node.name))
643+
if node.name in self._input_to_node_name[op_input]:
644+
self._input_to_node_name[op_input].remove(node.name)
645+
629646
self._nodes.remove(node)
630647
node.graph = None
631648

@@ -649,16 +666,26 @@ def reset_nodes(self, ops):
649666
self.contained_graphs = remained_sub_graphs
650667
self._nodes_by_name = {op.name: op for op in ops}
651668
self._output_to_node_name = {}
669+
self._input_to_node_name = {}
652670
for op in ops:
653671
for op_output in op.output:
654672
self._output_to_node_name[op_output] = op.name
673+
for op_input in op.input:
674+
if op_input not in self._input_to_node_name:
675+
self._input_to_node_name[op_input] = set()
676+
self._input_to_node_name[op_input].add(op.name)
655677

656678
for n in self._order_sensitive_inputs:
657679
if n not in ops:
658680
self._order_sensitive_inputs.remove(n)
659681
for o in self.outputs:
660682
if o not in self._output_to_node_name:
661-
raise ValueError("graph output " + o + " not exist")
683+
raise ValueError("graph output %r not exist" % o)
684+
for i in self.inputs:
685+
if i.name.startswith('Placeholder'):
686+
continue
687+
if i.name not in self._input_to_node_name:
688+
raise ValueError("graph input %r not exist in graph." % i.name)
662689

663690
self._dtypes = remained_dtypes
664691
self._output_shapes = remained_shapes
@@ -775,6 +802,14 @@ def get_node_by_output_in_current_graph(self, output):
775802
ret = self._nodes_by_name.get(name)
776803
return ret
777804

805+
def get_node_by_input_in_current_graph(self, input):
806+
"""Get nodes by node input id."""
807+
names = self._output_to_node_name.get(input)
808+
ret = None
809+
if name:
810+
ret = [self._nodes_by_name.get(name) for name in names]
811+
return ret
812+
778813
def get_node_by_name(self, name):
779814
"""Get node by name."""
780815
ret = self._nodes_by_name.get(name)
@@ -785,6 +820,10 @@ def set_node_by_name(self, node):
785820
self._nodes_by_name[node.name] = node
786821
for op_output in node.output:
787822
self._output_to_node_name[op_output] = node.name
823+
for name in node.input:
824+
if name not in self._input_to_node_name:
825+
self._input_to_node_name[name] = set()
826+
self._input_to_node_name[name].add(node.name)
788827

789828
def change_node_name(self, node, new_name):
790829
"""Remove node in current graph."""
@@ -1210,35 +1249,54 @@ def find_output_consumers(self, output_name):
12101249
nodes.extend(g.find_output_consumers(output_name))
12111250
return nodes
12121251

1213-
@staticmethod
1214-
def replace_all_inputs(ops, old_input, new_input):
1252+
def replace_all_inputs(self, ops, old_input, new_input):
12151253
"""Replace all inputs pointing to old_input with new_input."""
12161254
if old_input == new_input:
12171255
return
1256+
if new_input not in self._input_to_node_name:
1257+
self._input_to_node_name[new_input] = set()
1258+
1259+
to_ops = self._input_to_node_name.get(old_input, None)
1260+
if to_ops is None:
1261+
# This means old_input is a final output.
1262+
to_ops = set()
12181263

12191264
for node in ops:
12201265
if old_input in node.input and new_input in node.output:
12211266
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
1267+
self._input_to_node_name[new_input].add(node.name)
12221268

12231269
for i, input_name in enumerate(node.input):
12241270
if input_name == old_input:
12251271
node.input[i] = new_input
1272+
if node.name not in to_ops:
1273+
raise RuntimeError(
1274+
"Unable to replace %r by %r. Node %r is not using input %r." % (
1275+
old_input, new_input, node.name, old_input))
1276+
12261277

12271278
# modify references in sub graphs
12281279
body_graphs = node.get_body_graphs()
12291280
if body_graphs:
12301281
for g in body_graphs.values():
12311282
g.replace_all_inputs(g.get_nodes(), old_input, new_input)
12321283

1233-
@staticmethod
1234-
def replace_input(node, old_input, new_input):
1284+
def replace_input(self, node, old_input, new_input):
12351285
"""Replace node."""
12361286
assert isinstance(node, Node) and isinstance(old_input, six.text_type) and isinstance(new_input, six.text_type)
12371287
is_replaced = False
12381288
for i, input_name in enumerate(node.input):
12391289
if input_name == old_input:
12401290
node.input[i] = new_input
12411291
is_replaced = True
1292+
1293+
to_ops = self._input_to_node_name.get(old_input, None)
1294+
if to_ops is not None:
1295+
to_ops.remove(node.name)
1296+
if new_input not in self._input_to_node_name:
1297+
self._input_to_node_name[new_input] = set()
1298+
self._input_to_node_name[new_input].add(node.name)
1299+
12421300
return is_replaced
12431301

12441302
def _extract_sub_graph_nodes(self, dest_node, input_checker=None):

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def version_7(cls, ctx, node, **kwargs):
524524
maximum_iterations_name = node.input[1]
525525
maximum_iterations = node.inputs[1].get_tensor_value()
526526
if maximum_iterations == -1:
527-
maximum_iterations = sys.maxsize
527+
maximum_iterations = np.iinfo(dtype_loop).max
528528
consumers = ctx.find_output_consumers(maximum_iterations_name)
529529
external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
530530
if len(external_consumers) == 0:

0 commit comments

Comments
 (0)