@@ -126,7 +126,7 @@ def merge_duplicated_transposes(self):
126
126
transpose_out = transposes [0 ].output [0 ]
127
127
for node in transposes [1 :]:
128
128
old_transpose_out = node .output [0 ]
129
- graph .replace_all_inputs (graph . get_nodes () , old_transpose_out , transpose_out , keep_ops = False )
129
+ graph .replace_all_inputs (None , old_transpose_out , transpose_out , keep_ops = False ) # graph.get_nodes( )
130
130
131
131
# dangling transpose nodes can be deleted
132
132
graph .delete_unused_nodes (graph .outputs )
@@ -209,7 +209,7 @@ def _handle_node_having_branches(self, node):
209
209
for n in input_transposes :
210
210
n_input = n .input [0 ]
211
211
utils .make_sure (len (n .output ) == 1 , "only expect single output" )
212
- self ._g .replace_all_inputs (self . _g . get_nodes () , n .output [0 ], n_input , keep_ops = False )
212
+ self ._g .replace_all_inputs (None , n .output [0 ], n_input , keep_ops = False ) # self._g.get_nodes( )
213
213
self ._g .remove_node (n .name )
214
214
215
215
utils .make_sure (len (node .output ) == 1 , "only expect single output" )
@@ -220,7 +220,7 @@ def _handle_node_having_branches(self, node):
220
220
for n in output_transposes :
221
221
n_input = n .input [0 ]
222
222
utils .make_sure (len (n .output ) == 1 , "only expect single output" )
223
- self ._g .replace_all_inputs (self . _g . get_nodes () , n .output [0 ], n_input , keep_ops = False )
223
+ self ._g .replace_all_inputs (None , n .output [0 ], n_input , keep_ops = False ) # self._g.get_nodes( )
224
224
self ._g .remove_node (n .name )
225
225
226
226
shape = self ._g .get_shape (node .output [0 ])
@@ -249,7 +249,7 @@ def _switch_transpose_and_node(self, node, trans):
249
249
250
250
input_index = self ._get_input_index_for_trans (node , trans )
251
251
252
- self ._g .replace_all_inputs (self . _g . get_nodes () , node .output [0 ], trans .output [0 ], keep_ops = False )
252
+ self ._g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # self._g.get_nodes( )
253
253
self ._g .replace_input (node , node .input [input_index ], trans .input [0 ], input_index )
254
254
self ._g .replace_input (trans , trans .input [0 ], node .output [0 ], 0 )
255
255
@@ -288,7 +288,7 @@ def _handle_nhwc_tranpose(self, trans):
288
288
return False
289
289
290
290
def _remove_useless_tranpose (self , trans ):
291
- self ._g .replace_all_inputs (self . _g . get_nodes () , trans .output [0 ], trans .input [0 ], keep_ops = False )
291
+ self ._g .replace_all_inputs (None , trans .output [0 ], trans .input [0 ], keep_ops = False ) # self._g.get_nodes( )
292
292
self ._g .remove_node (trans .name )
293
293
294
294
def _nodes_has_single_consumer_node (self , nodes ):
@@ -408,7 +408,7 @@ def _add_handler(self, trans, node):
408
408
conv_inputs = [t_p .input [0 ], t_p .input [1 ], node .input [1 ]]
409
409
conv_node = self ._g .make_node (t_p .type , conv_inputs , attr = t_p .attr_onnx )
410
410
self ._g .replace_input (trans , trans .input [0 ], utils .port_name (conv_node .name ), 0 )
411
- self ._g .replace_all_inputs (self . _g . get_nodes () , node .output [0 ], trans .output [0 ], keep_ops = False )
411
+ self ._g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # self._g.get_nodes( )
412
412
self ._g .remove_node (t_p .name )
413
413
self ._g .remove_node (node .name )
414
414
return True
@@ -417,7 +417,7 @@ def _add_handler(self, trans, node):
417
417
def _transpose_handler (self , trans , node ):
418
418
if is_nchw_transpose (node ):
419
419
for g in {self ._g , node .graph }:
420
- g .replace_all_inputs (g . get_nodes () , node .output [0 ], trans .input [0 ], keep_ops = False )
420
+ g .replace_all_inputs (None , node .output [0 ], trans .input [0 ], keep_ops = False ) # g.get_nodes( )
421
421
422
422
shape = node .graph .get_shape (node .output [0 ])
423
423
dtype = node .graph .get_dtype (node .output [0 ])
@@ -475,7 +475,7 @@ def _mul_handler(self, trans, node):
475
475
conv .inputs [1 ].set_tensor_value (np .transpose (result , (3 , 2 , 0 , 1 )))
476
476
477
477
ops = self ._g .get_nodes ()
478
- self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ], keep_ops = False )
478
+ self ._g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # ops
479
479
self ._g .remove_node (node .name )
480
480
return True
481
481
@@ -523,7 +523,7 @@ def _sum_handler(self, trans, node):
523
523
524
524
# switch to trans(sum(x1, x2, x3, ...))
525
525
ops = self ._g .get_nodes ()
526
- self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ], keep_ops = False )
526
+ self ._g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # ops
527
527
new_input = [n .output [0 ] if n .is_const () else n .input [0 ] for n in inputs ]
528
528
self ._g .replace_inputs (node , new_input )
529
529
self ._g .replace_input (trans , trans .input [0 ], node .output [0 ], 0 )
@@ -547,7 +547,7 @@ def _identity_handler(self, trans, node):
547
547
if node .output [0 ] in node .graph .outputs :
548
548
return False
549
549
for g in {self ._g , node .graph }:
550
- g .replace_all_inputs (g . get_nodes () , node .output [0 ], trans .output [0 ], keep_ops = False )
550
+ g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # g.get_nodes( )
551
551
node .graph .remove_node (node .name )
552
552
return True
553
553
@@ -587,7 +587,7 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
587
587
if node .get_attr ("axes" ):
588
588
# switch tran and squeeze
589
589
# 1 switch
590
- self ._g .replace_all_inputs (self . _g . get_nodes () , node .output [0 ], trans .output [0 ], keep_ops = False )
590
+ self ._g .replace_all_inputs (None , node .output [0 ], trans .output [0 ], keep_ops = False ) # self._g.get_nodes( )
591
591
self ._g .replace_input (node , node .input [0 ], trans .input [0 ], 0 )
592
592
self ._g .replace_input (trans , trans .input [0 ], node .output [0 ], 0 )
593
593
# 2 correct attr of nodes
0 commit comments