@@ -482,7 +482,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
482
482
body_graph .parent_graph = self
483
483
new_node .set_body_graph_as_attr (attr_name , body_graph )
484
484
485
- self .replace_all_inputs (self .get_nodes (), o , new_output_name )
485
+ self .replace_all_inputs (self .get_nodes (), o , new_output_name , keep_ops = True )
486
486
self .make_node ("Identity" , [new_output_name ], outputs = [o ], op_name_scope = n .name + "_" + "graph_outputs" )
487
487
self .copy_shape (new_output_name , o )
488
488
self .copy_dtype (new_output_name , o )
@@ -1269,7 +1269,7 @@ def insert_new_node_on_output(self, op_type, output_name, name, domain=None, **k
1269
1269
new_node = self .make_node (op_type , [output_name ], attr = kwargs , outputs = [new_output ], name = name , domain = domain )
1270
1270
1271
1271
to_replace = [n for n in self .get_nodes () if n != new_node ]
1272
- self .replace_all_inputs (to_replace , output_name , new_output )
1272
+ self .replace_all_inputs (to_replace , output_name , new_output , keep_ops = True )
1273
1273
return new_node
1274
1274
1275
1275
def find_output_consumers (self , output_name ):
@@ -1286,8 +1286,11 @@ def find_output_consumers(self, output_name):
1286
1286
nodes .extend (g .find_output_consumers (output_name ))
1287
1287
return nodes
1288
1288
1289
- def replace_all_inputs (self , ops , old_input , new_input ):
1290
- """Replace all inputs pointing to old_input with new_input."""
1289
+ def replace_all_inputs (self , ops , old_input , new_input , keep_ops = True ):
1290
+ """
1291
+ Replace all inputs pointing to old_input with new_input.
1292
+ *ops* is unused unless keep_ops is True.
1293
+ """
1291
1294
if old_input == new_input :
1292
1295
return
1293
1296
if new_input not in self ._input_to_node_name :
0 commit comments