Skip to content

Commit 2be64a6

Browse files
committed
add a method to modify input index
1 parent b76fafa commit 2be64a6

File tree

1 file changed

+60
-73
lines changed

1 file changed

+60
-73
lines changed

tf2onnx/graph.py

Lines changed: 60 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -590,13 +590,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
590590
onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
591591

592592
for name2 in onnx_node.input:
593-
if name2 not in self._input_to_node_name:
594-
self._input_to_node_name[name2] = set()
595-
self._input_to_node_name[name2].add(onnx_node.name)
596-
if self.parent_graph is not None:
597-
if name2 not in self.parent_graph._input_to_graph:
598-
self.parent_graph._input_to_graph[name2] = {}
599-
self.parent_graph._input_to_graph[name2][id(self)] = self
593+
self._register_input_name(name2, onnx_node)
600594

601595
if op_type in ["If", "Loop", "Scan"]:
602596
# we force the op containing inner graphs not skipped during conversion.
@@ -637,13 +631,7 @@ def append_node(self, node):
637631
self.set_dtype(name, output_dtypes[i])
638632
self.set_shape(name, output_shapes[i])
639633
for name in node.input:
640-
if name not in self._input_to_node_name:
641-
self._input_to_node_name[name] = set()
642-
self._input_to_node_name[name].add(node.name)
643-
if self.parent_graph is not None:
644-
if name not in self.parent_graph._input_to_graph:
645-
self.parent_graph._input_to_graph[name] = {}
646-
self.parent_graph._input_to_graph[name][id(self)] = self
634+
self._register_input_name(name, node)
647635

648636
def remove_node(self, node_name):
649637
"""Remove node in current graph."""
@@ -667,9 +655,8 @@ def remove_node(self, node_name):
667655
for op_input in node.input:
668656
if op_input not in self._input_to_node_name:
669657
raise RuntimeError(
670-
"Input %r of node %r not found." % (op_input, node.name))
671-
if node.name in self._input_to_node_name[op_input]:
672-
self._input_to_node_name[op_input].remove(node.name)
658+
"Input %r of node %r not found." % (op_input, node_name))
659+
self._unregister_input_name(op_input, node)
673660

674661
self._nodes.remove(node)
675662
node.graph = None
@@ -705,13 +692,7 @@ def reset_nodes(self, ops):
705692
else:
706693
inps = op.input
707694
for op_input in inps:
708-
if op_input not in self._input_to_node_name:
709-
self._input_to_node_name[op_input] = set()
710-
self._input_to_node_name[op_input].add(op.name)
711-
if self.parent_graph is not None:
712-
if op_input not in self.parent_graph._input_to_graph:
713-
self.parent_graph._input_to_graph[op_input] = {}
714-
self.parent_graph._input_to_graph[op_input][id(self)] = self
695+
self._register_input_name(op_input, op)
715696

716697
for n in self._order_sensitive_inputs:
717698
if n not in ops:
@@ -861,13 +842,7 @@ def set_node_by_name(self, node):
861842
for op_output in node.output:
862843
self._output_to_node_name[op_output] = node.name
863844
for name in node.input:
864-
if name not in self._input_to_node_name:
865-
self._input_to_node_name[name] = set()
866-
self._input_to_node_name[name].add(node.name)
867-
if self.parent_graph is not None:
868-
if name not in self.parent_graph._input_to_graph:
869-
self.parent_graph._input_to_graph[name] = {}
870-
self.parent_graph._input_to_graph[name][id(self)] = self
845+
self._register_input_name(name, node)
871846

872847
def change_node_name(self, node, new_name):
873848
"""Remove node in current graph."""
@@ -1233,10 +1208,7 @@ def remove_input(self, node, to_be_removed, i=None):
12331208

12341209
for i2, name in enumerate(node.input):
12351210
if name == to_be_removed:
1236-
if node.input[i2] in self._input_to_node_name:
1237-
to_ops = self._input_to_node_name[node.input[i2]]
1238-
if node.name in to_ops:
1239-
to_ops.remove(node.name)
1211+
self._unregister_input_name(node.input[i2], node)
12401212
del node.input[i2]
12411213
break
12421214
# don't remove output from parent since others might depend on it
@@ -1308,16 +1280,47 @@ def find_output_consumers(self, output_name):
13081280
if output_name in node.input:
13091281
nodes.append(node)
13101282

1311-
for node in self.get_nodes():
1312-
# find consumers in sub graphs,
1313-
# should we keep an index of nodes including
1314-
# a subgraphs?
1315-
body_graphs = node.get_body_graphs()
1316-
if body_graphs:
1317-
for g in body_graphs.values():
1318-
nodes.extend(g.find_output_consumers(output_name))
1283+
if output_name in self._input_to_graph:
1284+
for idg, g in self._input_to_graph[output_name].items():
1285+
nodes.extend(g.find_output_consumers(output_name))
1286+
else:
1287+
for node in self.get_nodes():
1288+
# find consumers in sub graphs,
1289+
# should we keep an index of nodes including
1290+
# a subgraphs?
1291+
body_graphs = node.get_body_graphs()
1292+
if body_graphs:
1293+
for g in body_graphs.values():
1294+
ext = g.find_output_consumers(output_name)
1295+
if len(ext) > 0:
1296+
raise RuntimeError(
1297+
"Inconsistency in _input_to_graph.")
1298+
# nodes.extend(ext)
13191299
return nodes
13201300

1301+
def _register_input_name(self, input_name, node, only_graph=False):
1302+
if not only_graph:
1303+
if input_name not in self._input_to_node_name:
1304+
self._input_to_node_name[input_name] = set()
1305+
self._input_to_node_name[input_name].add(node.name)
1306+
if self.parent_graph is not None:
1307+
if input_name not in self.parent_graph._input_to_graph:
1308+
self.parent_graph._input_to_graph[input_name] = {}
1309+
self.parent_graph._input_to_graph[input_name][id(self)] = self
1310+
self.parent_graph._register_input_name(input_name, node, only_graph=True)
1311+
1312+
def _unregister_input_name(self, input_name, node, only_graph=False):
1313+
node_name = node.name
1314+
if not only_graph:
1315+
if node_name in self._input_to_node_name[input_name]:
1316+
if node_name in self._input_to_node_name[input_name]:
1317+
self._input_to_node_name[input_name].remove(node_name)
1318+
if (self.parent_graph is not None and
1319+
input_name in self.parent_graph._input_to_graph and
1320+
id(self) in self.parent_graph._input_to_graph[input_name]):
1321+
del self.parent_graph._input_to_graph[input_name][id(self)]
1322+
self.parent_graph._unregister_input_name(input_name, node, only_graph=True)
1323+
13211324
def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
13221325
"""
13231326
Replace all inputs pointing to old_input with new_input.
@@ -1327,10 +1330,6 @@ def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
13271330
return
13281331
if new_input not in self._input_to_node_name:
13291332
self._input_to_node_name[new_input] = set()
1330-
if self.parent_graph is not None:
1331-
if new_input not in self.parent_graph._input_to_graph:
1332-
self.parent_graph._input_to_graph[new_input] = {}
1333-
self.parent_graph._input_to_graph[new_input][id(self)] = self
13341333

13351334
if keep_ops and ops is not None:
13361335
pass
@@ -1344,22 +1343,22 @@ def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
13441343
continue
13451344
if old_input in node.input and new_input in node.output:
13461345
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
1347-
self._input_to_node_name[new_input].add(node.name)
1348-
if self.parent_graph is not None:
1349-
if new_input not in self.parent_graph._input_to_graph:
1350-
self.parent_graph._input_to_graph[new_input] = {}
1351-
self.parent_graph._input_to_graph[new_input][id(self)] = self
1346+
self._register_input_name(new_input, node)
13521347

13531348
for i, input_name in enumerate(node.input):
13541349
if input_name == old_input:
13551350
self.replace_input(node, node.input[i], new_input, i)
13561351

1357-
for node in self.get_nodes():
1358-
# modify references in sub graphs
1359-
body_graphs = node.get_body_graphs()
1360-
if body_graphs:
1361-
for g in body_graphs.values():
1362-
g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
1352+
if old_input in self._input_to_graph:
1353+
for idg, g in self._input_to_graph[old_input].items():
1354+
g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
1355+
#~ else:
1356+
#~ for node in self.get_nodes():
1357+
#~ # modify references in sub graphs
1358+
#~ body_graphs = node.get_body_graphs()
1359+
#~ if body_graphs:
1360+
#~ for g in body_graphs.values():
1361+
#~ g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
13631362

13641363
def replace_input(self, node, old_input, new_input, i=None):
13651364
"""Replace one input in a node."""
@@ -1382,14 +1381,7 @@ def replace_input(self, node, old_input, new_input, i=None):
13821381
# A node may take twice the same entry.
13831382
to_ops.remove(node.name)
13841383

1385-
if new_input not in self._input_to_node_name:
1386-
self._input_to_node_name[new_input] = set()
1387-
self._input_to_node_name[new_input].add(node.name)
1388-
if self.parent_graph is not None:
1389-
if new_input not in self.parent_graph._input_to_graph:
1390-
self.parent_graph._input_to_graph[new_input] = {}
1391-
self.parent_graph._input_to_graph[new_input][id(self)] = self
1392-
1384+
self._register_input_name(new_input, node)
13931385
return is_replaced
13941386

13951387
def replace_inputs(self, node, new_inputs):
@@ -1405,13 +1397,8 @@ def replace_inputs(self, node, new_inputs):
14051397

14061398
for input_name in new_inputs:
14071399
assert isinstance(input_name, six.text_type)
1408-
if input_name not in self._input_to_node_name:
1409-
self._input_to_node_name[input_name] = set()
1410-
self._input_to_node_name[input_name].add(node.name)
1411-
if self.parent_graph is not None:
1412-
if input_name not in self.parent_graph._input_to_graph:
1413-
self.parent_graph._input_to_graph[input_name] = {}
1414-
self.parent_graph._input_to_graph[input_name][id(self)] = self
1400+
self._register_input_name(input_name, node)
1401+
14151402
node.input = new_inputs
14161403
return True
14171404

0 commit comments

Comments
 (0)