@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]

-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
     ctx.copy_shape(name, cast_node.output[0])
@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
         output_name = node.output[0]
         # cast each inputs to float
         for i, inp in enumerate(node.inputs):
-            input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-            input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+            input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                      to=onnx_pb.TensorProto.FLOAT)
             ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
         next_nodes = ctx.find_output_consumers(node.output[0])
         # cast output back to dtype unless the next op is a cast
         if next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                        to=dtype)
             ctx.set_dtype(output_cast.output[0], dtype)
             ctx.copy_shape(output_name, output_cast.output[0])
@@ -161,15 +160,14 @@ def version_5(cls, ctx, node, **kwargs):
             return

         # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.output[0], input_cast.output[0])

         # if the next node is already a cast we don't need to insert another one
         next_nodes = ctx.find_output_consumers(node.output[0])
         if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-            output_cast.set_attr("to", dtype)
+            output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                        to=dtype)
             ctx.set_dtype(output_cast.output[0], dtype)
             ctx.copy_shape(node.output[0], output_cast.output[0])
@@ -743,16 +741,17 @@ def version_1(cls, ctx, node, **kwargs):
         if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
             # override the previous cast
             cast_node = node.inputs[0]
+            cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         else:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                     to=onnx_pb.TensorProto.FLOAT)
             nodes.insert(0, cast_node)
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.input[0], cast_node.output[0])
         # undo the cast afer slice
         name = utils.make_name(node.name)
-        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-        cast_node.set_attr("to", input_dtype)
+        cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                                  to=input_dtype)
         ctx.set_dtype(cast_node.output[0], input_dtype)
         ctx.copy_shape(node.output[0], cast_node.output[0])
         nodes.append(cast_node)
@@ -1181,8 +1180,7 @@ def version_1(cls, ctx, node, **kwargs):
         if dtype == onnx_pb.TensorProto.INT64:
             return
         op_name = utils.make_name(node.name)
-        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(node.output[0], output_cast.output[0])
@@ -1556,8 +1554,7 @@ def version_8(cls, ctx, node, **kwargs):
         seq_len_dtype = ctx.get_dtype(node.input[1])
         if seq_len_dtype != onnx_pb.TensorProto.INT64:
-            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-            cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+            cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
             ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
             ctx.copy_shape(node.input[1], cast_node.output[0])
@@ -1763,8 +1760,8 @@ def version_11(cls, ctx, node, **kwargs):
         # cast to int64 if needed
         if dtypes[1] != onnx_pb.TensorProto.UINT64:
             cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                      name=utils.make_name(node.name) + "_cast")
-            cast_node.set_attr("to", dtypes[1])
+                                                      name=utils.make_name(node.name) + "_cast",
+                                                      to=dtypes[1])
             ctx.set_dtype(cast_node.output[0], dtypes[1])
             ctx.copy_shape(node.output[1], cast_node.output[0])
         # FIXME: the indices in onnx are not the same as in tensorflow.
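
For reference, a minimal standalone sketch of the cast-wrapping pattern these hunks converge on, assuming a tf2onnx-style graph context whose helpers (insert_new_node_on_input, insert_new_node_on_output, find_output_consumers, set_dtype, copy_shape, child_name) behave exactly as used above and accept node attributes such as "to" as keyword arguments at creation time; the helper name wrap_op_with_float_cast is illustrative only, not part of this commit.

from onnx import onnx_pb


def wrap_op_with_float_cast(ctx, node, dtype):
    """Run `node` in float32 by casting its input up front and its output back to `dtype`."""
    # Cast the op's first input to float; the Cast "to" attribute is passed at node creation.
    input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                              to=onnx_pb.TensorProto.FLOAT)
    ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
    ctx.copy_shape(node.input[0], input_cast.output[0])

    # Cast the output back to the original dtype unless the sole consumer is already a Cast.
    next_nodes = ctx.find_output_consumers(node.output[0])
    if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0],
                                                    name=node.child_name(),
                                                    to=dtype)
        ctx.set_dtype(output_cast.output[0], dtype)
        ctx.copy_shape(node.output[0], output_cast.output[0])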