@@ -306,7 +306,7 @@ def test_duplicated_duplicated_input(self):
306306
307307 graph = helper .make_graph (
308308 [node0 , node1 , node2 , node3 , node4 ],
309- "transpose-merge-test " ,
309+ "test_duplicated_duplicated_input " ,
310310 [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
311311 [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 , 5 ))],
312312 )
@@ -325,14 +325,74 @@ def test_duplicated_duplicated_attributes(self):
325325
326326 graph = helper .make_graph (
327327 [node0 , node1 , node2 , node3 , node4 ],
328- "transpose-merge-test " ,
328+ "test_duplicated_duplicated_attributes " ,
329329 [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
330330 [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (5 ,))],
331331 )
332332
333333 model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
334334 self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
335335 op_type = "ReduceSum" , remaining_op_num = 2 )
336+
337+ def test_duplicated_node_is_graph_output (self ):
338+ node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
339+ node1 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value1" ])
340+ node2 = helper .make_node ('Add' , inputs = ["value1" , "X" ], outputs = ["value2" ])
341+
342+ graph = helper .make_graph (
343+ [node0 , node1 , node2 ],
344+ "test_duplicated_node_is_graph_output" ,
345+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 , 5 ))],
346+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 , 5 )),
347+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 , 5 ))],
348+ )
349+
350+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
351+ self .run_merge_duplicated_nodes_compare (["value1" , "value2" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
352+ op_type = "Add" , remaining_op_num = 2 )
353+
354+ def test_duplicated_different_output_length (self ):
355+ node0 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value0" ])
356+ node1 = helper .make_node ('Dropout' , inputs = ["X" ], outputs = ["value1" , "mask" ])
357+ node2 = helper .make_node ('Dropout' , inputs = ["value1" ], outputs = ["value2" ])
358+
359+ graph = helper .make_graph (
360+ [node0 , node1 , node2 ],
361+ "test_duplicated_different_output_length" ,
362+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
363+ [helper .make_tensor_value_info ("value1" , TensorProto .FLOAT , (5 ,)),
364+ helper .make_tensor_value_info ("mask" , TensorProto .BOOL , (5 ,)),
365+ helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (5 ,))],
366+ )
367+
368+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
369+ self .run_merge_duplicated_nodes_compare (["value1" , "mask" , "value2" ],
370+ {"X" : np .random .randn (5 ,).astype (np .float32 )},
371+ model_proto ,
372+ op_type = "Dropout" , remaining_op_num = 2 )
373+
374+ def test_duplicated_need_multiple_run (self ):
375+ node00 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value00" ])
376+ node01 = helper .make_node ('Log' , inputs = ["value00" ], outputs = ["value01" ])
377+ node02 = helper .make_node ('Log' , inputs = ["value01" ], outputs = ["value02" ])
378+
379+ node10 = helper .make_node ('Log' , inputs = ["X" ], outputs = ["value10" ])
380+ node11 = helper .make_node ('Log' , inputs = ["value10" ], outputs = ["value11" ])
381+ node12 = helper .make_node ('Log' , inputs = ["value11" ], outputs = ["value12" ])
382+
383+ res = helper .make_node ('Add' , inputs = ["value02" , "value12" ], outputs = ["res" ])
384+
385+ graph = helper .make_graph (
386+ [node00 , node01 , node02 , node10 , node11 , node12 , res ],
387+ "test_duplicated_node_is_graph_output" ,
388+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (5 ,))],
389+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (5 ,))],
390+ )
391+
392+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
393+ self .run_merge_duplicated_nodes_compare (["res" ], {"X" : np .random .randn (5 ,).astype (np .float32 )},
394+ model_proto ,
395+ op_type = "Log" , remaining_op_num = 3 )
336396 # Merge Duplicated Nodes Optimizer Tests End
337397
338398
0 commit comments