@@ -590,13 +590,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
590
590
onnx_node = helper .make_node (op_type , inputs , outputs , name = name , domain = domain , ** raw_attr )
591
591
592
592
for name2 in onnx_node .input :
593
- if name2 not in self ._input_to_node_name :
594
- self ._input_to_node_name [name2 ] = set ()
595
- self ._input_to_node_name [name2 ].add (onnx_node .name )
596
- if self .parent_graph is not None :
597
- if name2 not in self .parent_graph ._input_to_graph :
598
- self .parent_graph ._input_to_graph [name2 ] = {}
599
- self .parent_graph ._input_to_graph [name2 ][id (self )] = self
593
+ self ._register_input_name (name2 , onnx_node )
600
594
601
595
if op_type in ["If" , "Loop" , "Scan" ]:
602
596
# we force the op containing inner graphs not skipped during conversion.
@@ -637,13 +631,7 @@ def append_node(self, node):
637
631
self .set_dtype (name , output_dtypes [i ])
638
632
self .set_shape (name , output_shapes [i ])
639
633
for name in node .input :
640
- if name not in self ._input_to_node_name :
641
- self ._input_to_node_name [name ] = set ()
642
- self ._input_to_node_name [name ].add (node .name )
643
- if self .parent_graph is not None :
644
- if name not in self .parent_graph ._input_to_graph :
645
- self .parent_graph ._input_to_graph [name ] = {}
646
- self .parent_graph ._input_to_graph [name ][id (self )] = self
634
+ self ._register_input_name (name , node )
647
635
648
636
def remove_node (self , node_name ):
649
637
"""Remove node in current graph."""
@@ -667,9 +655,8 @@ def remove_node(self, node_name):
667
655
for op_input in node .input :
668
656
if op_input not in self ._input_to_node_name :
669
657
raise RuntimeError (
670
- "Input %r of node %r not found." % (op_input , node .name ))
671
- if node .name in self ._input_to_node_name [op_input ]:
672
- self ._input_to_node_name [op_input ].remove (node .name )
658
+ "Input %r of node %r not found." % (op_input , node_name ))
659
+ self ._unregister_input_name (op_input , node )
673
660
674
661
self ._nodes .remove (node )
675
662
node .graph = None
@@ -705,13 +692,7 @@ def reset_nodes(self, ops):
705
692
else :
706
693
inps = op .input
707
694
for op_input in inps :
708
- if op_input not in self ._input_to_node_name :
709
- self ._input_to_node_name [op_input ] = set ()
710
- self ._input_to_node_name [op_input ].add (op .name )
711
- if self .parent_graph is not None :
712
- if op_input not in self .parent_graph ._input_to_graph :
713
- self .parent_graph ._input_to_graph [op_input ] = {}
714
- self .parent_graph ._input_to_graph [op_input ][id (self )] = self
695
+ self ._register_input_name (op_input , op )
715
696
716
697
for n in self ._order_sensitive_inputs :
717
698
if n not in ops :
@@ -861,13 +842,7 @@ def set_node_by_name(self, node):
861
842
for op_output in node .output :
862
843
self ._output_to_node_name [op_output ] = node .name
863
844
for name in node .input :
864
- if name not in self ._input_to_node_name :
865
- self ._input_to_node_name [name ] = set ()
866
- self ._input_to_node_name [name ].add (node .name )
867
- if self .parent_graph is not None :
868
- if name not in self .parent_graph ._input_to_graph :
869
- self .parent_graph ._input_to_graph [name ] = {}
870
- self .parent_graph ._input_to_graph [name ][id (self )] = self
845
+ self ._register_input_name (name , node )
871
846
872
847
def change_node_name (self , node , new_name ):
873
848
"""Remove node in current graph."""
@@ -1233,10 +1208,7 @@ def remove_input(self, node, to_be_removed, i=None):
1233
1208
1234
1209
for i2 , name in enumerate (node .input ):
1235
1210
if name == to_be_removed :
1236
- if node .input [i2 ] in self ._input_to_node_name :
1237
- to_ops = self ._input_to_node_name [node .input [i2 ]]
1238
- if node .name in to_ops :
1239
- to_ops .remove (node .name )
1211
+ self ._unregister_input_name (node .input [i2 ], node )
1240
1212
del node .input [i2 ]
1241
1213
break
1242
1214
# don't remove output from parent since others might depend on it
@@ -1308,16 +1280,47 @@ def find_output_consumers(self, output_name):
1308
1280
if output_name in node .input :
1309
1281
nodes .append (node )
1310
1282
1311
- for node in self .get_nodes ():
1312
- # find consumers in sub graphs,
1313
- # should we keep an index of nodes including
1314
- # a subgraphs?
1315
- body_graphs = node .get_body_graphs ()
1316
- if body_graphs :
1317
- for g in body_graphs .values ():
1318
- nodes .extend (g .find_output_consumers (output_name ))
1283
+ if output_name in self ._input_to_graph :
1284
+ for idg , g in self ._input_to_graph [output_name ].items ():
1285
+ nodes .extend (g .find_output_consumers (output_name ))
1286
+ else :
1287
+ for node in self .get_nodes ():
1288
+ # find consumers in sub graphs,
1289
+ # should we keep an index of nodes including
1290
+ # a subgraphs?
1291
+ body_graphs = node .get_body_graphs ()
1292
+ if body_graphs :
1293
+ for g in body_graphs .values ():
1294
+ ext = g .find_output_consumers (output_name )
1295
+ if len (ext ) > 0 :
1296
+ raise RuntimeError (
1297
+ "Inconsistency in _input_to_graph." )
1298
+ # nodes.extend(ext)
1319
1299
return nodes
1320
1300
1301
+ def _register_input_name (self , input_name , node , only_graph = False ):
1302
+ if not only_graph :
1303
+ if input_name not in self ._input_to_node_name :
1304
+ self ._input_to_node_name [input_name ] = set ()
1305
+ self ._input_to_node_name [input_name ].add (node .name )
1306
+ if self .parent_graph is not None :
1307
+ if input_name not in self .parent_graph ._input_to_graph :
1308
+ self .parent_graph ._input_to_graph [input_name ] = {}
1309
+ self .parent_graph ._input_to_graph [input_name ][id (self )] = self
1310
+ self .parent_graph ._register_input_name (input_name , node , only_graph = True )
1311
+
1312
+ def _unregister_input_name (self , input_name , node , only_graph = False ):
1313
+ node_name = node .name
1314
+ if not only_graph :
1315
+ if node_name in self ._input_to_node_name [input_name ]:
1316
+ if node_name in self ._input_to_node_name [input_name ]:
1317
+ self ._input_to_node_name [input_name ].remove (node_name )
1318
+ if (self .parent_graph is not None and
1319
+ input_name in self .parent_graph ._input_to_graph and
1320
+ id (self ) in self .parent_graph ._input_to_graph [input_name ]):
1321
+ del self .parent_graph ._input_to_graph [input_name ][id (self )]
1322
+ self .parent_graph ._unregister_input_name (input_name , node , only_graph = True )
1323
+
1321
1324
def replace_all_inputs (self , ops , old_input , new_input , keep_ops = False ):
1322
1325
"""
1323
1326
Replace all inputs pointing to old_input with new_input.
@@ -1327,10 +1330,6 @@ def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
1327
1330
return
1328
1331
if new_input not in self ._input_to_node_name :
1329
1332
self ._input_to_node_name [new_input ] = set ()
1330
- if self .parent_graph is not None :
1331
- if new_input not in self .parent_graph ._input_to_graph :
1332
- self .parent_graph ._input_to_graph [new_input ] = {}
1333
- self .parent_graph ._input_to_graph [new_input ][id (self )] = self
1334
1333
1335
1334
if keep_ops and ops is not None :
1336
1335
pass
@@ -1344,22 +1343,22 @@ def replace_all_inputs(self, ops, old_input, new_input, keep_ops=False):
1344
1343
continue
1345
1344
if old_input in node .input and new_input in node .output :
1346
1345
raise RuntimeError ("creating a circle in the graph is not allowed: " + node .name )
1347
- self ._input_to_node_name [new_input ].add (node .name )
1348
- if self .parent_graph is not None :
1349
- if new_input not in self .parent_graph ._input_to_graph :
1350
- self .parent_graph ._input_to_graph [new_input ] = {}
1351
- self .parent_graph ._input_to_graph [new_input ][id (self )] = self
1346
+ self ._register_input_name (new_input , node )
1352
1347
1353
1348
for i , input_name in enumerate (node .input ):
1354
1349
if input_name == old_input :
1355
1350
self .replace_input (node , node .input [i ], new_input , i )
1356
1351
1357
- for node in self .get_nodes ():
1358
- # modify references in sub graphs
1359
- body_graphs = node .get_body_graphs ()
1360
- if body_graphs :
1361
- for g in body_graphs .values ():
1362
- g .replace_all_inputs (g .get_nodes (), old_input , new_input , keep_ops = keep_ops )
1352
+ if old_input in self ._input_to_graph :
1353
+ for idg , g in self ._input_to_graph [old_input ].items ():
1354
+ g .replace_all_inputs (g .get_nodes (), old_input , new_input , keep_ops = keep_ops )
1355
+ #~ else:
1356
+ #~ for node in self.get_nodes():
1357
+ #~ # modify references in sub graphs
1358
+ #~ body_graphs = node.get_body_graphs()
1359
+ #~ if body_graphs:
1360
+ #~ for g in body_graphs.values():
1361
+ #~ g.replace_all_inputs(g.get_nodes(), old_input, new_input, keep_ops=keep_ops)
1363
1362
1364
1363
def replace_input (self , node , old_input , new_input , i = None ):
1365
1364
"""Replace one input in a node."""
@@ -1382,14 +1381,7 @@ def replace_input(self, node, old_input, new_input, i=None):
1382
1381
# A node may take twice the same entry.
1383
1382
to_ops .remove (node .name )
1384
1383
1385
- if new_input not in self ._input_to_node_name :
1386
- self ._input_to_node_name [new_input ] = set ()
1387
- self ._input_to_node_name [new_input ].add (node .name )
1388
- if self .parent_graph is not None :
1389
- if new_input not in self .parent_graph ._input_to_graph :
1390
- self .parent_graph ._input_to_graph [new_input ] = {}
1391
- self .parent_graph ._input_to_graph [new_input ][id (self )] = self
1392
-
1384
+ self ._register_input_name (new_input , node )
1393
1385
return is_replaced
1394
1386
1395
1387
def replace_inputs (self , node , new_inputs ):
@@ -1405,13 +1397,8 @@ def replace_inputs(self, node, new_inputs):
1405
1397
1406
1398
for input_name in new_inputs :
1407
1399
assert isinstance (input_name , six .text_type )
1408
- if input_name not in self ._input_to_node_name :
1409
- self ._input_to_node_name [input_name ] = set ()
1410
- self ._input_to_node_name [input_name ].add (node .name )
1411
- if self .parent_graph is not None :
1412
- if input_name not in self .parent_graph ._input_to_graph :
1413
- self .parent_graph ._input_to_graph [input_name ] = {}
1414
- self .parent_graph ._input_to_graph [input_name ][id (self )] = self
1400
+ self ._register_input_name (input_name , node )
1401
+
1415
1402
node .input = new_inputs
1416
1403
return True
1417
1404
0 commit comments