@@ -805,8 +805,10 @@ def version_1(cls, ctx, node, **kwargs):
805
805
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
806
806
# ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
807
807
name = utils .make_name (node .name )
808
- squeeze_node = ctx .insert_new_node_on_output ("Squeeze" , node .output [0 ], name )
809
- squeeze_node .set_attr ("axes" , needs_squeeze )
808
+ squeeze_node = GraphBuilder (ctx ).make_squeeze (
809
+ {"axes" : needs_squeeze , 'data' : node .output [0 ]}, name = name , return_node = True )
810
+ ctx .insert_node_on_output (squeeze_node )
811
+
810
812
nodes .append (squeeze_node )
811
813
input_dtype = ctx .get_dtype (node .output [0 ])
812
814
ctx .set_dtype (squeeze_node .output [0 ], input_dtype )
@@ -1023,8 +1025,9 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
1023
1025
node = GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = out_dtypes , shapes = out_shapes )
1024
1026
node = ctx .get_node_by_output (node )
1025
1027
if needs_squeeze :
1026
- squeeze_node = ctx .insert_new_node_on_output ("Squeeze" , node .output [0 ], node .child_name ())
1027
- squeeze_node .set_attr ("axes" , needs_squeeze )
1028
+ squeeze_node = GraphBuilder (ctx ).make_squeeze (
1029
+ {"axes" : needs_squeeze , "data" : node .output [0 ]}, name = node .child_name (), return_node = True )
1030
+ ctx .insert_node_on_output (squeeze_node , node .output [0 ])
1028
1031
input_dtype = ctx .get_dtype (node .output [0 ])
1029
1032
ctx .set_dtype (squeeze_node .output [0 ], input_dtype )
1030
1033
ctx .copy_shape (node .output [0 ], squeeze_node .output [0 ])
@@ -1348,7 +1351,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
1348
1351
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1349
1352
if len (input_shape ) == 3 :
1350
1353
# insert automatically an Unsqueeze op if the input is 3d
1351
- unsqz1 = ctx .make_node ("Unsqueeze" , input_tensor .output , {"axes" : [3 ]})
1354
+ unsqz1 = GraphBuilder (ctx ).make_unsqueeze (
1355
+ {"axes" : [3 ], "data" : input_tensor .output [0 ]}, return_node = True )
1352
1356
trans1 = ctx .make_node ("Transpose" , unsqz1 .output , {"perm" : [3 , 0 , 1 , 2 ]})
1353
1357
else :
1354
1358
trans1 = ctx .make_node ("Transpose" , input_tensor .output , {"perm" : [3 , 0 , 1 , 2 ]})
@@ -1377,22 +1381,21 @@ def any_version(cls, opset, ctx, node, **kwargs):
1377
1381
kwargs = {** inputs_map }
1378
1382
ctx .remove_node (node .name )
1379
1383
slice1 = GraphBuilder (ctx ).make_slice (kwargs )
1380
- ctx . make_node ( "Squeeze" , [ slice1 ], { "axes" : [ 3 ]},
1381
- outputs = node .output , name = node .name , dtypes = dtypes , shapes = shapes )
1384
+ GraphBuilder ( ctx ). make_squeeze (
1385
+ { "axes" : [ 3 ], "data" : slice1 , " outputs" : node .output } , name = node .name , dtypes = dtypes , shapes = shapes )
1382
1386
else :
1383
1387
kwargs = {** inputs_map , "outputs" : node .output }
1384
1388
ctx .remove_node (node .name )
1385
1389
GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = dtypes , shapes = shapes )
1386
1390
else :
1387
1391
def mknode (optype , inputs , attrs = None ):
1388
1392
nodename = utils .make_name (node .name + '_' + optype .lower ())
1389
- if opset < 13 or optype != 'Squeeze' :
1390
- return ctx .make_node (optype , inputs , attrs , name = nodename )
1391
- inputs .append (attrs ['axes' ])
1392
- attrs = attrs .copy ()
1393
- attrs .pop ('axes' )
1393
+ if opset >= 13 and optype == 'Squeeze' :
1394
+ return GraphBuilder (ctx ).make_squeeze (
1395
+ {"axes" : attrs ['axes' ], "data" : inputs [0 ]}, name = nodename , return_node = True )
1394
1396
return ctx .make_node (optype , inputs , attrs , name = nodename )
1395
1397
1398
+
1396
1399
# support non 3D/4D tensors and dynamic crop vals
1397
1400
# dynamic slice starts at opset 10
1398
1401
utils .make_sure (ctx .opset >= 11 , 'non-4D tensor or non-const crops require opset 11' )
0 commit comments