@@ -27,7 +27,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
2727
2828 origin_model_path = self .save_onnx_model (origin_proto , onnx_feed_dict , postfix = "_origin" )
2929
30- new_proto = GraphUtil .optimize_graph_with_model_proto (origin_proto )
30+ new_proto = GraphUtil .optimize_model_proto (origin_proto )
3131
3232 self .assertTrue (new_proto , msg = "model proto after optimizer should not be None" )
3333
@@ -287,7 +287,115 @@ def test_identity_in_subgraph_non_graph_output(self):
287287 self .run_identity_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
288288 model_proto , remaining_identity_num = 0 )
289289
290- # Tranpose Optimizer Tests End
290+ # Identity Optimizer Tests End
291+
292+ # Merge Duplicated Nodes Optimizer Tests Start
293+
294+ def run_merge_duplicated_nodes_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
295+ op_type = None , remaining_op_num = None , debug = False , rtol = 1e-07 ):
296+ self .run_and_compare (output_names_with_port , onnx_feed_dict , origin_proto , op_type = op_type ,
297+ remaining_op_num = remaining_op_num , debug = debug , rtol = rtol )
298+
299+ def test_duplicated_duplicated_input (self ):
300+ # same input or not
301+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
302+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
303+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
304+ node3 = helper .make_node ("Mul" , ["value0" , "value2" ], ["value3" ])
305+ node4 = helper .make_node ("Mul" , ["value1" , "value3" ], ["OUT" ])
306+
307+ graph = helper .make_graph (
308+ [node0 , node1 , node2 , node3 , node4 ],
309+ "test_duplicated_duplicated_input" ,
310+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
311+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 , 5 ))],
312+ )
313+
314+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
315+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
316+ op_type = "Add" , remaining_op_num = 2 )
317+
318+ def test_duplicated_duplicated_attributes (self ):
319+ # same attr or not
320+ node0 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value0" ], axes = [0 ], keepdims = 0 )
321+ node1 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value1" ], axes = [0 ], keepdims = 0 )
322+ node2 = helper .make_node ('ReduceSum' , inputs = ["X" ], outputs = ["value2" ], axes = [1 ], keepdims = 0 )
323+ node3 = helper .make_node ('Add' , inputs = ["value0" , "value1" ], outputs = ["value3" ])
324+ node4 = helper .make_node ("Mul" , ["value2" , "value3" ], ["OUT" ])
325+
326+ graph = helper .make_graph (
327+ [node0 , node1 , node2 , node3 , node4 ],
328+ "test_duplicated_duplicated_attributes" ,
329+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
330+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 ,))],
331+ )
332+
333+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
334+ self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
335+ op_type = "ReduceSum" , remaining_op_num = 2 )
336+
337+ def test_duplicated_node_is_graph_output (self ):
338+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
339+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
340+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
341+
342+ graph = helper .make_graph (
343+ [node0 , node1 , node2 ],
344+ "test_duplicated_node_is_graph_output" ,
345+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
346+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 , 5 )),
347+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 , 5 ))],
348+ )
349+
350+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
351+ self .run_merge_duplicated_nodes_compare (["value1" , "value2" ],
352+ {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
353+ op_type = "Add" , remaining_op_num = 2 )
354+
355+ def test_duplicated_different_output_length (self ):
356+ node0 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value0" ])
357+ node1 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value1" , "mask" ])
358+ node2 = helper .make_node ('Dropout' , inputs = ["value1" ], outputs = ["value2" ])
359+
360+ graph = helper .make_graph (
361+ [node0 , node1 , node2 ],
362+ "test_duplicated_different_output_length" ,
363+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
364+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 ,)),
365+ helper .make_tensor_value_info ("mask" , TensorProto .BOOL , (5 ,)),
366+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 ,))],
367+ )
368+
369+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
370+ self .run_merge_duplicated_nodes_compare (["value1" , "mask" , "value2" ],
371+ {"X" : np .random .randn (5 ,).astype (np .float32 )},
372+ model_proto ,
373+ op_type = "Dropout" , remaining_op_num = 2 )
374+
375+ def test_duplicated_need_multiple_run (self ):
376+ node00 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value00" ])
377+ node01 = helper .make_node ('Log' , inputs = ["value00" ], outputs = ["value01" ])
378+ node02 = helper .make_node ('Log' , inputs = ["value01" ], outputs = ["value02" ])
379+
380+ node10 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value10" ])
381+ node11 = helper .make_node ('Log' , inputs = ["value10" ], outputs = ["value11" ])
382+ node12 = helper .make_node ('Log' , inputs = ["value11" ], outputs = ["value12" ])
383+
384+ res = helper .make_node ('Add' , inputs = ["value02" , "value12" ], outputs = ["res" ])
385+
386+ graph = helper .make_graph (
387+ [node00 , node01 , node02 , node10 , node11 , node12 , res ],
388+ "test_duplicated_node_is_graph_output" ,
389+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
390+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (5 ,))],
391+ )
392+
393+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
394+ self .run_merge_duplicated_nodes_compare (["res" ], {"X" : np .random .randn (5 ,).astype (np .float32 )},
395+ model_proto ,
396+ op_type = "Log" , remaining_op_num = 3 )
397+ # Merge Duplicated Nodes Optimizer Tests End
398+
291399
292400if __name__ == "__main__" :
293401 unittest_main ()
0 commit comments