Skip to content

Commit 62d8b4b

Browse files
committed
introduce keep_ops=True
1 parent ff634a5 commit 62d8b4b

21 files changed

+48
-44
lines changed

tf2onnx/graph.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
482482
body_graph.parent_graph = self
483483
new_node.set_body_graph_as_attr(attr_name, body_graph)
484484

485-
self.replace_all_inputs(self.get_nodes(), o, new_output_name)
485+
self.replace_all_inputs(self.get_nodes(), o, new_output_name, keep_ops=True)
486486
self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
487487
self.copy_shape(new_output_name, o)
488488
self.copy_dtype(new_output_name, o)
@@ -1269,7 +1269,7 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
12691269
new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
12701270

12711271
to_replace = [n for n in self.get_nodes() if n != new_node]
1272-
self.replace_all_inputs(to_replace, output_name, new_output)
1272+
self.replace_all_inputs(to_replace, output_name, new_output, keep_ops=True)
12731273
return new_node
12741274

12751275
def find_output_consumers(self, output_name):
@@ -1286,8 +1286,11 @@ def find_output_consumers(self, output_name):
12861286
nodes.extend(g.find_output_consumers(output_name))
12871287
return nodes
12881288

1289-
def replace_all_inputs(self, ops, old_input, new_input):
1290-
"""Replace all inputs pointing to old_input with new_input."""
1289+
def replace_all_inputs(self, ops, old_input, new_input, keep_ops=True):
1290+
"""
1291+
Replace all inputs pointing to old_input with new_input.
1292+
*ops* is unused unless keep_ops is True.
1293+
"""
12911294
if old_input == new_input:
12921295
return
12931296
if new_input not in self._input_to_node_name:

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def prefix_graph(g, scope):
813813
if old_output == oname:
814814
g.outputs[i] = new_output
815815
break
816-
g.replace_all_inputs(ops, old_output, new_output)
816+
g.replace_all_inputs(ops, old_output, new_output, keep_ops=True)
817817
to_remove.append(node)
818818
for node in to_remove:
819819
g.remove_node(node.name)

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def version_1(cls, ctx, node, **kwargs):
447447
downstream_nodes = ctx.find_output_consumers(node.output[0])
448448
downstream_nodes.remove(output_shape)
449449
downstream_nodes.remove(slice_node)
450-
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
450+
ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0], keep_ops=True)
451451

452452
conv_dims_attr(node, "strides", spatial=spatial)
453453
conv_dims_attr(node, "dilations", spatial=spatial)

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def version_1(cls, ctx, node, **kwargs):
272272
axis_node = node.inputs[-1]
273273
utils.make_sure(axis_node.is_const(), "%r needs to be const", axis_node.name)
274274
axis_val = axis_node.get_tensor_value()
275-
ctx.remove_input(node, node.input[-1])
275+
ctx.remove_input(node, node.input[-1], len(node.input) - 1)
276276

277277
if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports.
278278
input_shape = ctx.get_shape(node.input[0])

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def _optimize_transpose(g, node, consumer_nodes):
138138
shape = g.get_shape(node2.output[0])
139139
dtype = g.get_dtype(node2.output[0])
140140
node2_consumers = g.find_output_consumers(node2.output[0])
141-
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
141+
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0], keep_ops=True)
142142
g.remove_node(node2.name)
143143
if set(node2.output) & set(g.outputs):
144144
g.make_node("Identity", [node.input[0]],
@@ -173,7 +173,7 @@ def _optimize_squeeze_unsqueeze(g, node, consumer_nodes):
173173
return []
174174

175175
node2_consumers = g.find_output_consumers(node2.output[0])
176-
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
176+
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0], keep_ops=True)
177177
g.remove_node(node.name)
178178
g.remove_node(node2.name)
179179
return []

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _replace_node_with_const(node, graph, vals):
9090
const_node = graph.make_const(utils.make_name("const_fold_opt"), val)
9191
graph.set_dtype(const_node.output[0], utils.map_numpy_to_onnx_dtype(val.dtype))
9292
graph.set_shape(const_node.output[0], val.shape)
93-
graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0])
93+
graph.replace_all_inputs(graph.get_nodes(), old_input, const_node.output[0], keep_ops=True)
9494
graph.remove_node(node.name)
9595

9696
@staticmethod

tf2onnx/optimizer/identity_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _optimize_at_current_graph_level(self, g):
4747
def _handle_non_graph_output_identity(graph, identity):
4848
old_name = identity.output[0]
4949
new_name = identity.input[0]
50-
graph.replace_all_inputs(graph.get_nodes(), old_name, new_name)
50+
graph.replace_all_inputs(graph.get_nodes(), old_name, new_name, keep_ops=True)
5151
graph.remove_node(identity.name)
5252
return True
5353

@@ -81,5 +81,5 @@ def _handle_graph_output_identity(self, graph, identity, graph_outputs):
8181
graph.set_shape(output_id, output_shape)
8282
graph.set_dtype(output_id, output_dtype)
8383

84-
graph.replace_all_inputs(graph.get_nodes(), input_id, output_id)
84+
graph.replace_all_inputs(graph.get_nodes(), input_id, output_id, keep_ops=True)
8585
return True

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _merge_nodes_that_are_duplicated(self, nodes_to_process, graph):
9999
if set(node_to_delete.output).intersection(set(graph.outputs)):
100100
continue
101101
for old_input, new_input in zip(node_to_delete.output, node_to_retain.output):
102-
graph.replace_all_inputs(graph.get_nodes(), old_input, new_input)
102+
graph.replace_all_inputs(graph.get_nodes(), old_input, new_input, keep_ops=True)
103103
graph.remove_node(node_to_delete.name)
104104
self._graph_can_be_optimized = True
105105

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,8 @@ def _add_handler(self, trans, node):
408408

409409
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
410410
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
411-
ops = self._g.get_nodes()
412411
self._g.replace_input(trans, trans.input[0], utils.port_name(conv_node.name), 0)
413-
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
412+
self._g.replace_all_inputs(self._g.get_nodes(), node.output[0], trans.output[0])
414413
self._g.remove_node(t_p.name)
415414
self._g.remove_node(node.name)
416415
return True
@@ -419,8 +418,7 @@ def _add_handler(self, trans, node):
419418
def _transpose_handler(self, trans, node):
420419
if is_nchw_transpose(node):
421420
for g in {self._g, node.graph}:
422-
ops = g.get_nodes()
423-
g.replace_all_inputs(ops, node.output[0], trans.input[0])
421+
g.replace_all_inputs(g.get_nodes(), node.output[0], trans.input[0])
424422

425423
shape = node.graph.get_shape(node.output[0])
426424
dtype = node.graph.get_dtype(node.output[0])

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,23 @@ def _connect_scan_with_output(self, context, scan_node):
151151
if self.g.opset == 8:
152152
nodes = self._adapt_scan_sequence_input_or_output("state_output_reshape",
153153
scan_node.output[index], True)
154-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
154+
self.g.replace_all_inputs(
155+
self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
155156
else: # since opset 9
156-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
157+
self.g.replace_all_inputs(
158+
self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
157159
index += 1
158160

159161
for out_tensor_value_info in context.loop_properties.scan_outputs_exits:
160162
if out_tensor_value_info.id:
161163
if self.g.opset == 8:
162164
nodes = self._adapt_scan_sequence_input_or_output("scan_output_reshape",
163165
scan_node.output[index], True)
164-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
166+
self.g.replace_all_inputs(
167+
self.g.get_nodes(), out_tensor_value_info.id, nodes[-1].output[0])
165168
else: # since opset 9
166-
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
169+
self.g.replace_all_inputs(
170+
self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
167171
index += 1
168172

169173
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):

0 commit comments

Comments
 (0)