@@ -1391,20 +1391,37 @@ def version_11(cls, ctx, node, **kwargs):
13911391 else :
13921392 mode = "nearest"
13931393 roi = ctx .make_const (utils .make_name ("roi" ), np .array ([]).astype (np .float32 ))
1394- const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ]).astype (np .int64 ))
1395- const_two = ctx .make_const (utils .make_name ("const_two" ), np .array ([2 ]).astype (np .int64 ))
1396- const_empty_float = ctx .make_const (utils .make_name ("const_empty_float" ), np .array ([]).astype (np .float32 ))
13971394 input_nchw = ctx .make_node ("Transpose" , [node .input [0 ]], {"perm" : constants .NHWC_TO_NCHW })
1398- shape_input = ctx .make_node ("Shape" , [input_nchw .output [0 ]])
1399- sliced_shape = ctx .make_node ("Slice" , [shape_input .output [0 ], const_zero .output [0 ], const_two .output [0 ]])
1400- size_int64 = ctx .make_node ("Cast" , [node .input [1 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1401- concat_shape = ctx .make_node ("Concat" , [sliced_shape .output [0 ], size_int64 .output [0 ]], {'axis' : 0 })
1402- resize_inputs = [
1403- input_nchw .output [0 ],
1404- roi .output [0 ],
1405- const_empty_float .output [0 ],
1406- concat_shape .output [0 ]
1407- ]
1395+ shape = ctx .get_shape (node .input [0 ])
1396+ if shape and shape [2 ] != - 1 and shape [1 ] != - 1 and node .inputs [1 ].is_const ():
1397+ target_shape = node .inputs [1 ].get_tensor_value ()
1398+ n , h , w , c = shape
1399+ nh , nw = target_shape
1400+ if "sizes" in node .attr :
1401+ sizes_val = np .array ([1.0 , 1.0 , nh , nw ]).astype (np .int64 )
1402+ resize_params = ctx .make_const (utils .make_name ("sizes" ), sizes_val , raw = False )
1403+ else : # scales
1404+ scale_val = np .array ([1.0 , 1.0 , float (nh ) / h , float (nw ) / w ]).astype (np .float32 )
1405+ resize_params = ctx .make_const (utils .make_name ("scales" ), scale_val , raw = False )
1406+ resize_inputs = [
1407+ input_nchw .output [0 ],
1408+ roi .output [0 ],
1409+ resize_params .output [0 ]
1410+ ]
1411+ else :
1412+ const_zero = ctx .make_const (utils .make_name ("const_zero" ), np .array ([0 ]).astype (np .int64 ))
1413+ const_two = ctx .make_const (utils .make_name ("const_two" ), np .array ([2 ]).astype (np .int64 ))
1414+ const_empty_float = ctx .make_const (utils .make_name ("const_empty_float" ), np .array ([]).astype (np .float32 ))
1415+ shape_input = ctx .make_node ("Shape" , [input_nchw .output [0 ]])
1416+ sliced_shape = ctx .make_node ("Slice" , [shape_input .output [0 ], const_zero .output [0 ], const_two .output [0 ]])
1417+ size_int64 = ctx .make_node ("Cast" , [node .input [1 ]], attr = {"to" : onnx_pb .TensorProto .INT64 })
1418+ concat_shape = ctx .make_node ("Concat" , [sliced_shape .output [0 ], size_int64 .output [0 ]], {'axis' : 0 })
1419+ resize_inputs = [
1420+ input_nchw .output [0 ],
1421+ roi .output [0 ],
1422+ const_empty_float .output [0 ],
1423+ concat_shape .output [0 ]
1424+ ]
14081425 transformation_mode = "asymmetric"
14091426 nearest_mode = "floor"
14101427 if "align_corners" in node .attr and node .attr ["align_corners" ].i :
0 commit comments