@@ -418,6 +418,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
418
418
self ._nodes = []
419
419
self ._nodes_by_name = {}
420
420
self ._output_to_node_name = {}
421
+ self ._input_to_node_name = {}
421
422
self .shapes = {}
422
423
self .graph_name = graph_name or "tf2onnx"
423
424
self ._is_subgraph = is_subgraph
@@ -442,7 +443,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
442
443
443
444
ops = [Node (node , self ) for node in nodes ]
444
445
self .reset_nodes (ops )
445
-
446
+
446
447
if not is_subgraph :
447
448
# add identity node after each output, in case it is renamed during conversion.
448
449
for o in self .outputs :
@@ -569,6 +570,11 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
569
570
570
571
onnx_node = helper .make_node (op_type , inputs , outputs , name = name , domain = domain , ** raw_attr )
571
572
573
+ for name in onnx_node .input :
574
+ if name not in self ._input_to_node_name :
575
+ self ._input_to_node_name [name ] = set ()
576
+ self ._input_to_node_name [name ].add (onnx_node .name )
577
+
572
578
if op_type in ["If" , "Loop" , "Scan" ]:
573
579
# we force the op containing inner graphs not skipped during conversion.
574
580
skip_conversion = False
@@ -606,6 +612,10 @@ def append_node(self, node):
606
612
self ._output_to_node_name [name ] = node .name
607
613
self .set_dtype (name , output_dtypes [i ])
608
614
self .set_shape (name , output_shapes [i ])
615
+ for name in node .input :
616
+ if name not in self ._input_to_node_name :
617
+ self ._input_to_node_name [name ] = set ()
618
+ self ._input_to_node_name [name ].add (node .name )
609
619
610
620
def remove_node (self , node_name ):
611
621
"""Remove node in current graph."""
@@ -626,6 +636,13 @@ def remove_node(self, node_name):
626
636
if op_output in self ._dtypes :
627
637
del self ._dtypes [op_output ]
628
638
639
+ for op_input in node .input :
640
+ if op_input not in self ._input_to_node_name :
641
+ raise RuntimeError (
642
+ "Input %r of node %r not found." % (op_input , node .name ))
643
+ if node .name in self ._input_to_node_name [op_input ]:
644
+ self ._input_to_node_name [op_input ].remove (node .name )
645
+
629
646
self ._nodes .remove (node )
630
647
node .graph = None
631
648
@@ -649,16 +666,26 @@ def reset_nodes(self, ops):
649
666
self .contained_graphs = remained_sub_graphs
650
667
self ._nodes_by_name = {op .name : op for op in ops }
651
668
self ._output_to_node_name = {}
669
+ self ._input_to_node_name = {}
652
670
for op in ops :
653
671
for op_output in op .output :
654
672
self ._output_to_node_name [op_output ] = op .name
673
+ for op_input in op .input :
674
+ if op_input not in self ._input_to_node_name :
675
+ self ._input_to_node_name [op_input ] = set ()
676
+ self ._input_to_node_name [op_input ].add (op .name )
655
677
656
678
for n in self ._order_sensitive_inputs :
657
679
if n not in ops :
658
680
self ._order_sensitive_inputs .remove (n )
659
681
for o in self .outputs :
660
682
if o not in self ._output_to_node_name :
661
- raise ValueError ("graph output " + o + " not exist" )
683
+ raise ValueError ("graph output %r not exist" % o )
684
+ for i in self .inputs :
685
+ if i .name .startswith ('Placeholder' ):
686
+ continue
687
+ if i .name not in self ._input_to_node_name :
688
+ raise ValueError ("graph input %r not exist in graph." % i .name )
662
689
663
690
self ._dtypes = remained_dtypes
664
691
self ._output_shapes = remained_shapes
@@ -775,6 +802,14 @@ def get_node_by_output_in_current_graph(self, output):
775
802
ret = self ._nodes_by_name .get (name )
776
803
return ret
777
804
805
+ def get_node_by_input_in_current_graph (self , input ):
806
+ """Get nodes by node input id."""
807
+ names = self ._output_to_node_name .get (input )
808
+ ret = None
809
+ if name :
810
+ ret = [self ._nodes_by_name .get (name ) for name in names ]
811
+ return ret
812
+
778
813
def get_node_by_name (self , name ):
779
814
"""Get node by name."""
780
815
ret = self ._nodes_by_name .get (name )
@@ -785,6 +820,10 @@ def set_node_by_name(self, node):
785
820
self ._nodes_by_name [node .name ] = node
786
821
for op_output in node .output :
787
822
self ._output_to_node_name [op_output ] = node .name
823
+ for name in node .input :
824
+ if name not in self ._input_to_node_name :
825
+ self ._input_to_node_name [name ] = set ()
826
+ self ._input_to_node_name [name ].add (node .name )
788
827
789
828
def change_node_name (self , node , new_name ):
790
829
"""Remove node in current graph."""
@@ -1210,35 +1249,54 @@ def find_output_consumers(self, output_name):
1210
1249
nodes .extend (g .find_output_consumers (output_name ))
1211
1250
return nodes
1212
1251
1213
- @staticmethod
1214
- def replace_all_inputs (ops , old_input , new_input ):
1252
+ def replace_all_inputs (self , ops , old_input , new_input ):
1215
1253
"""Replace all inputs pointing to old_input with new_input."""
1216
1254
if old_input == new_input :
1217
1255
return
1256
+ if new_input not in self ._input_to_node_name :
1257
+ self ._input_to_node_name [new_input ] = set ()
1258
+
1259
+ to_ops = self ._input_to_node_name .get (old_input , None )
1260
+ if to_ops is None :
1261
+ # This means old_input is a final output.
1262
+ to_ops = set ()
1218
1263
1219
1264
for node in ops :
1220
1265
if old_input in node .input and new_input in node .output :
1221
1266
raise RuntimeError ("creating a circle in the graph is not allowed: " + node .name )
1267
+ self ._input_to_node_name [new_input ].add (node .name )
1222
1268
1223
1269
for i , input_name in enumerate (node .input ):
1224
1270
if input_name == old_input :
1225
1271
node .input [i ] = new_input
1272
+ if node .name not in to_ops :
1273
+ raise RuntimeError (
1274
+ "Unable to replace %r by %r. Node %r is not using input %r." % (
1275
+ old_input , new_input , node .name , old_input ))
1276
+
1226
1277
1227
1278
# modify references in sub graphs
1228
1279
body_graphs = node .get_body_graphs ()
1229
1280
if body_graphs :
1230
1281
for g in body_graphs .values ():
1231
1282
g .replace_all_inputs (g .get_nodes (), old_input , new_input )
1232
1283
1233
- @staticmethod
1234
- def replace_input (node , old_input , new_input ):
1284
+ def replace_input (self , node , old_input , new_input ):
1235
1285
"""Replace node."""
1236
1286
assert isinstance (node , Node ) and isinstance (old_input , six .text_type ) and isinstance (new_input , six .text_type )
1237
1287
is_replaced = False
1238
1288
for i , input_name in enumerate (node .input ):
1239
1289
if input_name == old_input :
1240
1290
node .input [i ] = new_input
1241
1291
is_replaced = True
1292
+
1293
+ to_ops = self ._input_to_node_name .get (old_input , None )
1294
+ if to_ops is not None :
1295
+ to_ops .remove (node .name )
1296
+ if new_input not in self ._input_to_node_name :
1297
+ self ._input_to_node_name [new_input ] = set ()
1298
+ self ._input_to_node_name [new_input ].add (node .name )
1299
+
1242
1300
return is_replaced
1243
1301
1244
1302
def _extract_sub_graph_nodes (self , dest_node , input_checker = None ):
0 commit comments