Skip to content

Commit 6570acd

Browse files
committed
refactoring
1 parent 2be64a6 commit 6570acd

17 files changed

+41
-59
lines changed

benchmark/conversion_time.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def profile(profiler="pyinstrument", name="EfficientNetB2", show_all=False,
6363
graph_def, model = create(name, module)
6464
print("profile(%r, %r, %r)" % (profiler, name, module))
6565
if profiler == "spy":
66+
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
6667
convert(graph_def, model)
6768
elif profiler == "pyinstrument":
6869
from pyinstrument import Profiler

tests/test_internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_rewrite_subgraph(self):
139139
op_name = utils.make_name("ReplacedOp")
140140
out_name = utils.port_name(op_name)
141141
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
142-
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
142+
g.replace_all_inputs(None, output_node.output[0], new_node.output[0]) # ops
143143
for n in set(match.get_nodes()):
144144
g.remove_node(n.name)
145145
g.topological_sort(ops)

tf2onnx/graph.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,27 +1281,15 @@ def find_output_consumers(self, output_name):
12811281
nodes.append(node)
12821282

12831283
if output_name in self._input_to_graph:
1284-
for idg, g in self._input_to_graph[output_name].items():
1284+
for _, g in self._input_to_graph[output_name].items():
12851285
nodes.extend(g.find_output_consumers(output_name))
1286-
else:
1287-
for node in self.get_nodes():
1288-
# find consumers in sub graphs,
1289-
# should we keep an index of nodes including
1290-
# a subgraphs?
1291-
body_graphs = node.get_body_graphs()
1292-
if body_graphs:
1293-
for g in body_graphs.values():
1294-
ext = g.find_output_consumers(output_name)
1295-
if len(ext) > 0:
1296-
raise RuntimeError(
1297-
"Inconsistency in _input_to_graph.")
1298-
# nodes.extend(ext)
12991286
return nodes
13001287

13011288
def _register_input_name(self, input_name, node, only_graph=False):
1289+
"Register node taking a specific input."
13021290
if not only_graph:
13031291
if input_name not in self._input_to_node_name:
1304-
self._input_to_node_name[input_name] = set()
1292+
self._input_to_node_name[input_name] = set()
13051293
self._input_to_node_name[input_name].add(node.name)
13061294
if self.parent_graph is not None:
13071295
if input_name not in self.parent_graph._input_to_graph:
@@ -1310,6 +1298,7 @@ def _register_input_name(self, input_name, node, only_graph=False):
13101298
self.parent_graph._register_input_name(input_name, node, only_graph=True)
13111299

13121300
def _unregister_input_name(self, input_name, node, only_graph=False):
1301+
"Unregister node taking a specific input."
13131302
node_name = node.name
13141303
if not only_graph:
13151304
if node_name in self._input_to_node_name[input_name]:
@@ -1350,15 +1339,8 @@ def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
13501339
self.replace_input(node, node.input[i], new_input, i)
13511340

13521341
if old_input in self._input_to_graph:
1353-
for idg, g in self._input_to_graph[old_input].items():
1342+
for _, g in self._input_to_graph[old_input].items():
13541343
g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
1355-
#~ else:
1356-
#~ for node in self.get_nodes():
1357-
#~ # modify references in sub graphs
1358-
#~ body_graphs = node.get_body_graphs()
1359-
#~ if body_graphs:
1360-
#~ for g in body_graphs.values():
1361-
#~ g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
13621344

13631345
def replace_input(self, node, old_input, new_input, i=None):
13641346
"""Replace one input in a node."""

tf2onnx/onnx_opset/controlflow.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class TensorListStack:
492492
def version_7(cls, ctx, node, **kwargs):
493493
if node.inputs[0].is_while():
494494
ctx.remove_node(node.name)
495-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0], keep_ops=False)
495+
ctx.replace_all_inputs(None, node.output[0], node.input[0], keep_ops=False) # ctx.get_nodes()
496496

497497

498498
@tf_op(["While", "StatelessWhile"])
@@ -577,12 +577,11 @@ def version_7(cls, ctx, node, **kwargs):
577577
del output_names[idx]
578578
del body.outputs[idx]
579579

580-
removed_scan_outputs = {}
581580
# remove tensor array that are passed in to the loop
582581
for idx, n in reversed(to_remove):
583582
ctx.remove_node(n.name)
584583
# make the node output bad
585-
ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC", keep_ops=False)
584+
ctx.replace_all_inputs(None, n.output[0], "@@ALLOC", keep_ops=False) # ctx.get_nodes()
586585
del body.func_inputs[idx]
587586
del cond_graph.func_inputs[idx]
588587
del tf_while_inputs[idx]
@@ -618,7 +617,7 @@ def version_7(cls, ctx, node, **kwargs):
618617

619618
# shift output consumers
620619
for k, v in output_map.items():
621-
ctx.replace_all_inputs(ctx.get_nodes(), k, v, keep_ops=False)
620+
ctx.replace_all_inputs(None, k, v, keep_ops=False) # ctx.get_nodes()
622621

623622
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
624623
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)

tf2onnx/onnx_opset/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,4 @@ def atan2(y, x):
695695
"Add", inputs=[atan_node.output[0], pi_part.output[0]],
696696
op_name_scope=node.name + 'all',
697697
shapes=[shape], dtypes=[onnx_dtype])
698-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0], keep_ops=False)
698+
ctx.replace_all_inputs(None, node.output[0], last_node.output[0], keep_ops=False) # ctx.get_nodes()

tf2onnx/onnx_opset/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def version_1(cls, ctx, node, **kwargs):
3030
# if identity has a const as input, remove it
3131
input_name = node.input[0]
3232
output_name = node.output[0]
33-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name, keep_ops=False)
33+
ctx.replace_all_inputs(None, output_name, input_name, keep_ops=False) # ctx.get_nodes()
3434
ctx.remove_node(node.name)
3535

3636

tf2onnx/onnx_opset/quantize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def version_10(cls, ctx, node, **kwargs):
7878
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
7979
op_name_scope=node.name, attr={"axis": axis},
8080
shapes=[shape], dtypes=[dtype])
81-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0], keep_ops=False)
81+
ctx.replace_all_inputs(None, node.output[0], last_node.output[0], keep_ops=False) # ctx.get_nodes()

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def make_sigmoid(i, w, b):
153153
h_node = ctx.make_node("Mul", [co_node.output[0], o])
154154

155155
def replace_output(old_output, new_output):
156-
ctx.replace_all_inputs(ctx.get_nodes(), old_output, new_output, keep_ops=False)
156+
ctx.replace_all_inputs(None, old_output, new_output, keep_ops=False) # ctx.get_nodes()
157157
ctx.copy_dtype(old_output, new_output)
158158
ctx.copy_shape(old_output, new_output)
159159

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def version_1(cls, ctx, node, **kwargs):
119119
# if identity has a const as input, remove it
120120
input_name = node.input[0]
121121
output_name = node.output[0]
122-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name, keep_ops=False)
122+
ctx.replace_all_inputs(None, output_name, input_name, keep_ops=False) # ctx.get_nodes()
123123
ctx.remove_node(node.name)
124124

125125

@@ -129,7 +129,7 @@ class IdentityN:
129129
def version_1(cls, ctx, node, **kwargs):
130130
ctx.remove_node(node.name)
131131
for input_name, output_name in zip(node.input, node.output):
132-
ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name, keep_ops=False)
132+
ctx.replace_all_inputs(None, output_name, input_name, keep_ops=False) # ctx.get_nodes()
133133

134134

135135
@tf_op("Reshape")
@@ -1051,7 +1051,7 @@ def version_1(cls, ctx, node, **kwargs):
10511051
# concat all unqueezes
10521052
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
10531053
shapes=shapes, dtypes=dtypes)
1054-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], concat.output[0], keep_ops=False)
1054+
ctx.replace_all_inputs(None, node.output[0], concat.output[0], keep_ops=False) # ctx.get_nodes()
10551055

10561056

10571057
@tf_op("Unpack")

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def merge_duplicated_transposes(self):
126126
transpose_out = transposes[0].output[0]
127127
for node in transposes[1:]:
128128
old_transpose_out = node.output[0]
129-
graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out, keep_ops=False)
129+
graph.replace_all_inputs(None, old_transpose_out, transpose_out, keep_ops=False) # graph.get_nodes()
130130

131131
# dangling transpose nodes can be deleted
132132
graph.delete_unused_nodes(graph.outputs)
@@ -209,7 +209,7 @@ def _handle_node_having_branches(self, node):
209209
for n in input_transposes:
210210
n_input = n.input[0]
211211
utils.make_sure(len(n.output) == 1, "only expect single output")
212-
self._g.replace_all_inputs(self._g.get_nodes(), n.output[0], n_input, keep_ops=False)
212+
self._g.replace_all_inputs(None, n.output[0], n_input, keep_ops=False) # self._g.get_nodes()
213213
self._g.remove_node(n.name)
214214

215215
utils.make_sure(len(node.output) == 1, "only expect single output")
@@ -220,7 +220,7 @@ def _handle_node_having_branches(self, node):
220220
for n in output_transposes:
221221
n_input = n.input[0]
222222
utils.make_sure(len(n.output) == 1, "only expect single output")
223-
self._g.replace_all_inputs(self._g.get_nodes(), n.output[0], n_input, keep_ops=False)
223+
self._g.replace_all_inputs(None, n.output[0], n_input, keep_ops=False) # self._g.get_nodes()
224224
self._g.remove_node(n.name)
225225

226226
shape = self._g.get_shape(node.output[0])
@@ -249,7 +249,7 @@ def _switch_transpose_and_node(self, node, trans):
249249

250250
input_index = self._get_input_index_for_trans(node, trans)
251251

252-
self._g.replace_all_inputs(self._g.get_nodes(), node.output[0], trans.output[0], keep_ops=False)
252+
self._g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # self._g.get_nodes()
253253
self._g.replace_input(node, node.input[input_index], trans.input[0], input_index)
254254
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
255255

@@ -288,7 +288,7 @@ def _handle_nhwc_tranpose(self, trans):
288288
return False
289289

290290
def _remove_useless_tranpose(self, trans):
291-
self._g.replace_all_inputs(self._g.get_nodes(), trans.output[0], trans.input[0], keep_ops=False)
291+
self._g.replace_all_inputs(None, trans.output[0], trans.input[0], keep_ops=False) # self._g.get_nodes()
292292
self._g.remove_node(trans.name)
293293

294294
def _nodes_has_single_consumer_node(self, nodes):
@@ -408,7 +408,7 @@ def _add_handler(self, trans, node):
408408
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
409409
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
410410
self._g.replace_input(trans, trans.input[0], utils.port_name(conv_node.name), 0)
411-
self._g.replace_all_inputs(self._g.get_nodes(), node.output[0], trans.output[0], keep_ops=False)
411+
self._g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # self._g.get_nodes()
412412
self._g.remove_node(t_p.name)
413413
self._g.remove_node(node.name)
414414
return True
@@ -417,7 +417,7 @@ def _add_handler(self, trans, node):
417417
def _transpose_handler(self, trans, node):
418418
if is_nchw_transpose(node):
419419
for g in {self._g, node.graph}:
420-
g.replace_all_inputs(g.get_nodes(), node.output[0], trans.input[0], keep_ops=False)
420+
g.replace_all_inputs(None, node.output[0], trans.input[0], keep_ops=False) # g.get_nodes()
421421

422422
shape = node.graph.get_shape(node.output[0])
423423
dtype = node.graph.get_dtype(node.output[0])
@@ -475,7 +475,7 @@ def _mul_handler(self, trans, node):
475475
conv.inputs[1].set_tensor_value(np.transpose(result, (3, 2, 0, 1)))
476476

477477
ops = self._g.get_nodes()
478-
self._g.replace_all_inputs(ops, node.output[0], trans.output[0], keep_ops=False)
478+
self._g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # ops
479479
self._g.remove_node(node.name)
480480
return True
481481

@@ -523,7 +523,7 @@ def _sum_handler(self, trans, node):
523523

524524
# switch to trans(sum(x1, x2, x3, ...))
525525
ops = self._g.get_nodes()
526-
self._g.replace_all_inputs(ops, node.output[0], trans.output[0], keep_ops=False)
526+
self._g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # ops
527527
new_input = [n.output[0] if n.is_const() else n.input[0] for n in inputs]
528528
self._g.replace_inputs(node, new_input)
529529
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
@@ -547,7 +547,7 @@ def _identity_handler(self, trans, node):
547547
if node.output[0] in node.graph.outputs:
548548
return False
549549
for g in {self._g, node.graph}:
550-
g.replace_all_inputs(g.get_nodes(), node.output[0], trans.output[0], keep_ops=False)
550+
g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # g.get_nodes()
551551
node.graph.remove_node(node.name)
552552
return True
553553

@@ -587,7 +587,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
587587
if node.get_attr("axes"):
588588
# switch tran and squeeze
589589
# 1 switch
590-
self._g.replace_all_inputs(self._g.get_nodes(), node.output[0], trans.output[0], keep_ops=False)
590+
self._g.replace_all_inputs(None, node.output[0], trans.output[0], keep_ops=False) # self._g.get_nodes()
591591
self._g.replace_input(node, node.input[0], trans.input[0], 0)
592592
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
593593
# 2 correct attr of nodes

0 commit comments

Comments
 (0)