@@ -1192,7 +1192,8 @@ def version_1(cls, ctx, node, **kwargs):
1192
1192
# for each output we need to squeeze axis
1193
1193
for n in node .output :
1194
1194
op_name = utils .make_name (node .name )
1195
- squeeze_node = ctx .insert_new_node_on_output ("Squeeze" , n , name = op_name , axes = [axis ])
1195
+ squeeze_node = GraphBuilder (ctx ).make_squeeze ({'data' : n , 'axes' : [axis ]}, name = op_name , return_node = True )
1196
+ ctx .insert_node_on_output (squeeze_node , n )
1196
1197
ctx .copy_shape (n , squeeze_node .output [0 ])
1197
1198
ctx .copy_dtype (n , squeeze_node .output [0 ])
1198
1199
@@ -1256,8 +1257,8 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
1256
1257
depth = GraphBuilder (ctx ).make_unsqueeze ({'data' : node .input [1 ], 'axes' : [0 ]})
1257
1258
on_value = node .input [2 ]
1258
1259
off_value = node .input [3 ]
1259
- on_value = ctx . make_node ( "Unsqueeze" , [ on_value ], attr = { " axes" : [0 ]}). output [ 0 ]
1260
- off_value = ctx . make_node ( "Unsqueeze" , [ off_value ], attr = { " axes" : [0 ]}). output [ 0 ]
1260
+ on_value = GraphBuilder ( ctx ). make_unsqueeze ({ 'data' : on_value , ' axes' : [0 ]})
1261
+ off_value = GraphBuilder ( ctx ). make_unsqueeze ({ 'data' : off_value , ' axes' : [0 ]})
1261
1262
off_on_value = ctx .make_node ("Concat" , [off_value , on_value ], attr = {"axis" : 0 }).output [0 ]
1262
1263
1263
1264
indices = node .input [0 ]
@@ -1637,8 +1638,9 @@ def any_version(cls, opset, ctx, node, **kwargs):
1637
1638
# add valid_outputs count
1638
1639
output_idx = 2 if node .type in ["NonMaxSuppressionV5" ] else 1
1639
1640
shape_op = ctx .make_node ("Shape" , inputs = [nms_output .output [0 ]])
1640
- reduce_op = ctx .make_node ("ReduceSum" , inputs = shape_op .output , attr = {"axes" : [0 ], "keepdims" : 0 })
1641
- ctx .make_node ("Cast" , inputs = [reduce_op .output [0 ]], attr = {"to" : onnx_pb .TensorProto .INT32 },
1641
+ reduce_op = GraphBuilder (ctx ).make_reduce_sum (
1642
+ {"data" : shape_op .output [0 ], "axes" : [0 ], "keepdims" : 0 , "noop_with_empty_axes" : 1 })
1643
+ ctx .make_node ("Cast" , inputs = [reduce_op ], attr = {"to" : onnx_pb .TensorProto .INT32 },
1642
1644
outputs = [node .output [output_idx ]], dtypes = dtypes [output_idx ], shapes = shapes [output_idx ],
1643
1645
op_name_scope = node .name )
1644
1646
@@ -2385,15 +2387,17 @@ def normalize():
2385
2387
pad_length_2 = body_graph .make_node ('Concat' , [zeo , pad_length .output [0 ]], attr = {'axis' : 0 })
2386
2388
padded_range = body_graph .make_node ('Pad' , [sliced_range .output [0 ], pad_length_2 .output [0 ]])
2387
2389
# opset == 11, no need to change unsqueeze
2388
- unsqueezed_range = body_graph .make_node ('Unsqueeze' , [padded_range .output [0 ]], attr = {'axes' : [1 ]})
2390
+ unsqueezed_range = GraphBuilder (body_graph ).make_unsqueeze (
2391
+ {'data' : padded_range .output [0 ], 'axes' : [1 ]}, return_node = True )
2389
2392
half_shape_x = body_graph .make_node ('Slice' ,
2390
2393
[new_shape .output [0 ], zeo , minus_two ])
2391
2394
shape_range = body_graph .make_node ('Shape' , [unsqueezed_range .output [0 ]])
2392
2395
full_shape = body_graph .make_node ('Concat' , [half_shape_x .output [0 ], shape_range .output [0 ]], attr = {'axis' : 0 })
2393
2396
expanded_range = body_graph .make_node ('Expand' , [unsqueezed_range .output [0 ], full_shape .output [0 ]])
2394
2397
gathered_input = body_graph .make_node ('GatherElements' , [processed_input .output [0 ], expanded_range .output [0 ]],
2395
2398
attr = {'axis' : - 1 })
2396
- squeezed_input = body_graph .make_node ('Squeeze' , [gathered_input .output [0 ]], attr = {'axes' : [- 1 ]})
2399
+ squeezed_input = GraphBuilder (body_graph ).make_squeeze (
2400
+ {'data' : gathered_input .output [0 ], 'axes' : [- 1 ]}, return_node = True )
2397
2401
left_width = body_graph .make_node ('Sub' , [new_width .output [0 ], abs_k .output [0 ]])
2398
2402
dims = body_graph .make_node ('Concat' , [left_width .output [0 ], new_depth .output [0 ]], attr = {'axis' : 0 })
2399
2403
valid_dim = body_graph .make_node ('ReduceMin' , [dims .output [0 ]])
@@ -2505,8 +2509,8 @@ def normalize():
2505
2509
raw_output_shape + [- 1 ])
2506
2510
squeeze_sliced_graph = ctx .create_new_graph_with_same_config ()
2507
2511
squeeze_sliced_graph .parent_graph = ctx
2508
- squeeze_sliced = squeeze_sliced_graph . make_node ( 'Squeeze' , [ final_output_right_sliced . output [ 0 ]],
2509
- attr = { 'axes' : [- 2 ]})
2512
+ squeeze_sliced = GraphBuilder ( squeeze_sliced_graph ). make_squeeze (
2513
+ { 'data' : final_output_right_sliced . output [ 0 ], 'axes' : [- 2 ]}, return_node = True )
2510
2514
squeeze_sliced_graph .add_graph_output (squeeze_sliced .output [0 ], ctx .get_dtype (node .input [0 ]), raw_output_shape )
2511
2515
shapes = node .output_shapes
2512
2516
dtypes = node .output_dtypes
@@ -2680,14 +2684,14 @@ def version_13(cls, ctx, node, **kwargs):
2680
2684
@tf_op (["MatrixDiag" , "MatrixDiagV2" , "MatrixDiagV3" ])
2681
2685
class MatrixDiag :
2682
2686
@classmethod
2683
- def any_version (cls , opset , ctx , node , ** kwargs ):
2687
+ def version_12 (cls , ctx , node , ** kwargs ):
2684
2688
# Assemble MatrixDiagV3 by ReverseSequence
2685
2689
argc = len (node .input )
2686
2690
2687
- if opset >= 13 :
2688
- squeeze_axes0 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([0 ], dtype = np .int64 ))
2689
- squeeze_axes_1 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 1 ], dtype = np .int64 ))
2690
- squeeze_axes_2 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 2 ], dtype = np .int64 ))
2691
+ if ctx . opset >= 13 :
2692
+ squeeze_axes0 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([0 ], dtype = np .int64 )). output [ 0 ]
2693
+ squeeze_axes_1 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 1 ], dtype = np .int64 )). output [ 0 ]
2694
+ squeeze_axes_2 = ctx .make_const (utils .make_name ("const_axes" ), np .array ([- 2 ], dtype = np .int64 )). output [ 0 ]
2691
2695
2692
2696
minus_two , minus_one , zeo , one , two = \
2693
2697
[n .output [0 ] for n in ctx .make_consts ([[- 2 ], [- 1 ], [0 ], [1 ], [2 ]])]
@@ -2712,7 +2716,7 @@ def processdiag():
2712
2716
diag = node .input [0 ]
2713
2717
shape = ctx .get_shape (diag )
2714
2718
if len (shape ) == 1 :
2715
- if opset < 13 :
2719
+ if ctx . opset < 13 :
2716
2720
diag = mknode ("Unsqueeze" , [diag ], attr = {"axes" : [0 ]})
2717
2721
else :
2718
2722
diag = mknode ("Unsqueeze" , [diag , squeeze_axes0 ])
@@ -2737,7 +2741,7 @@ def id_diag():
2737
2741
def ex_diag ():
2738
2742
g = ctx .create_new_graph_with_same_config ()
2739
2743
g .parent_graph = ctx
2740
- if opset < 13 :
2744
+ if ctx . opset < 13 :
2741
2745
ex = mknode2 (g , "Unsqueeze" , [diag ], attr = {"axes" : [- 2 ]})
2742
2746
else :
2743
2747
ex = mknode2 (g , "Unsqueeze" , [diag , squeeze_axes_2 ])
@@ -2755,7 +2759,7 @@ def squeeze_12(name):
2755
2759
def squeeze_13 (name ):
2756
2760
return ctx .make_node ("Squeeze" , [name , squeeze_axes_1 ]).output [0 ]
2757
2761
2758
- squeeze = squeeze_12 if opset < 13 else squeeze_13
2762
+ squeeze = squeeze_12 if ctx . opset < 13 else squeeze_13
2759
2763
2760
2764
# gather inputs
2761
2765
diag , k , k_min , k_max , k_max_nxt = processdiag ()
@@ -3018,14 +3022,10 @@ def paddiag():
3018
3022
ctx .make_node ("Identity" , [padded ], name = node .name ,
3019
3023
outputs = node .output , shapes = shapes , dtypes = dtypes )
3020
3024
3021
- @classmethod
3022
- def version_12 (cls , ctx , node , ** kwargs ):
3023
- cls .any_version (12 , ctx , node , ** kwargs )
3024
-
3025
3025
@classmethod
3026
3026
def version_13 (cls , ctx , node , ** kwargs ):
3027
3027
# Parameters moved to inputs for operator Squeeze, Unsqueeze.
3028
- cls .any_version ( 13 , ctx , node , ** kwargs )
3028
+ cls .version_12 ( ctx , node , ** kwargs )
3029
3029
3030
3030
3031
3031
@tf_op ("MatrixSetDiagV3" )
0 commit comments