
Commit ff634a5

create operator Cast and specify value for parameter to in one call
1 parent 60b6813 commit ff634a5

File tree: 5 files changed (+32, -37 lines)

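Every file in this commit makes the same mechanical change: instead of creating a Cast node and then setting its "to" attribute with a separate set_attr call, the target dtype is passed as a keyword argument so the node is created and configured in a single call. A minimal before/after sketch of the pattern, modelled on the calls in this diff (assuming ctx is the converter graph and node an existing node being wrapped):

    # before: create the Cast node, then set its "to" attribute in a second step
    cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
    cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)

    # after: pass the attribute as a keyword argument at creation time
    cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                             to=onnx_pb.TensorProto.FLOAT)

The same pattern applies to insert_new_node_on_output, where the "to" value is passed alongside the name keyword.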

tf2onnx/graph.py

Lines changed: 2 additions & 2 deletions
@@ -411,8 +411,8 @@ def maybe_cast_input(self, supported, type_map):
     if tdtype is None:
         raise RuntimeError("don't know how to cast type {} on node {}".format(dtype, name))
     shape = self.graph.get_shape(name)
-    cast_node = self.graph.insert_new_node_on_input(self, "Cast", name)
-    cast_node.set_attr("to", tdtype)
+    cast_node = self.graph.insert_new_node_on_input(
+        self, "Cast", name, to=tdtype)
     self.graph.set_dtype(cast_node.output[0], [tdtype])
     self.graph.set_shape(cast_node.output[0], shape)
     did_cast = True
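
The one-call form relies on insert_new_node_on_input (and its output counterpart) forwarding extra keyword arguments into the attributes of the node it creates. Those helpers live in graph.py but are not part of this diff, so the following is only a hypothetical sketch of that forwarding under assumed names and behaviour, not the actual tf2onnx implementation:

    # hypothetical sketch, not the real tf2onnx helper
    def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
        """Create op_type in front of input_name and rewire node to read from it.

        Extra keyword arguments (e.g. to=TensorProto.INT64 for a Cast) are assumed
        to be passed through as ONNX attributes of the newly created node.
        """
        name = name or utils.make_name(node.name)
        new_node = self.make_node(op_type, [input_name], attr=kwargs, name=name)
        for i, inp in enumerate(node.input):
            if inp == input_name:
                # the consumer now reads from the freshly inserted node
                node.input[i] = new_node.output[0]
                break
        return new_node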

tf2onnx/onnx_opset/math.py

Lines changed: 3 additions & 3 deletions
@@ -545,9 +545,9 @@ def version_11(cls, ctx, node, **kwargs):
                         shapes=shapes, dtypes=dtypes, domain=constants.ONNX_DOMAIN, attr={'direction': direction})

     if node.maybe_cast_input([supported, supported], type_map):
-        cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                       name=utils.make_name(node.name) + "_castback")
-        cast_back_node.set_attr("to", dtypes[0])
+        cast_back_node = ctx.insert_new_node_on_output(
+            "Cast", node.output[0], name=utils.make_name(node.name) + "_castback",
+            to=dtypes[0])
         ctx.set_dtype(cast_back_node.output[0], dtypes[0])
         ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/nn.py

Lines changed: 6 additions & 8 deletions
@@ -637,14 +637,13 @@ def version_1(cls, ctx, node, **kwargs):
     origin_dtype = ctx.get_dtype(node.output[0])
     if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                             onnx_pb.TensorProto.DOUBLE]:
-        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
         ctx.copy_shape(node.name, cast_node.output[0])

         cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                        name=utils.make_name(node.name) + "_castback")
-        cast_back_node.set_attr("to", origin_dtype)
+                                                        name=utils.make_name(node.name) + "_castback",
+                                                        to=origin_dtype)
         ctx.set_dtype(cast_back_node.output[0], origin_dtype)
         ctx.copy_shape(node.name, cast_back_node.output[0])

@@ -667,14 +666,13 @@ def version_11(cls, ctx, node, **kwargs):
     origin_dtype = ctx.get_dtype(node.output[0])
     if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
                             TensorProto.INT32, TensorProto.INT64]:
-        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-        cast_node.set_attr("to", TensorProto.FLOAT)
+        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.FLOAT)
         ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
         ctx.copy_shape(node.name, cast_node.output[0])

         cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
-                                                        name=utils.make_name(node.name) + "_castback")
-        cast_back_node.set_attr("to", origin_dtype)
+                                                        name=utils.make_name(node.name) + "_castback",
+                                                        to=origin_dtype)
         ctx.set_dtype(cast_back_node.output[0], origin_dtype)
         ctx.copy_shape(node.name, cast_back_node.output[0])

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 20 deletions
@@ -31,8 +31,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
     """cast int32 shape into int64 shape."""
     name = node.input[input_number]

-    cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
-    cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+    cast_node = ctx.insert_new_node_on_input(node, "Cast", name, to=onnx_pb.TensorProto.INT64)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
     ctx.copy_shape(name, cast_node.output[0])

@@ -46,14 +45,14 @@ def _wrap_concat_with_cast(ctx, node):
     output_name = node.output[0]
     # cast each inputs to float
     for i, inp in enumerate(node.inputs):
-        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
-        input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i],
+                                                  to=onnx_pb.TensorProto.FLOAT)
         ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
     next_nodes = ctx.find_output_consumers(node.output[0])
     # cast output back to dtype unless the next op is a cast
     if next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
+                                                    to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(output_name, output_cast.output[0])

@@ -161,15 +160,14 @@ def version_5(cls, ctx, node, **kwargs):
         return

     # onnx < opset 8 does not know reshape for other types than float*, wrap the reshape in casts
-    input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
-    input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
+    input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=onnx_pb.TensorProto.FLOAT)
     ctx.copy_shape(node.output[0], input_cast.output[0])

     # if the next node is already a cast we don't need to insert another one
     next_nodes = ctx.find_output_consumers(node.output[0])
     if len(next_nodes) != 1 or next_nodes[0].type != "Cast":
-        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name())
-        output_cast.set_attr("to", dtype)
+        output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=node.child_name(),
+                                                    to=dtype)
         ctx.set_dtype(output_cast.output[0], dtype)
         ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -743,16 +741,17 @@ def version_1(cls, ctx, node, **kwargs):
     if node.inputs[0].type == "Cast" and len(ctx.find_output_consumers(node.inputs[0].output[0])) == 1:
         # override the previous cast
         cast_node = node.inputs[0]
+        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
     else:
-        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
+        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0],
+                                                 to=onnx_pb.TensorProto.FLOAT)
         nodes.insert(0, cast_node)
-    cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
     ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
     ctx.copy_shape(node.input[0], cast_node.output[0])
     # undo the cast afer slice
     name = utils.make_name(node.name)
-    cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
-    cast_node.set_attr("to", input_dtype)
+    cast_node = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name,
+                                              to=input_dtype)
     ctx.set_dtype(cast_node.output[0], input_dtype)
     ctx.copy_shape(node.output[0], cast_node.output[0])
     nodes.append(cast_node)

@@ -1181,8 +1180,7 @@ def version_1(cls, ctx, node, **kwargs):
     if dtype == onnx_pb.TensorProto.INT64:
         return
     op_name = utils.make_name(node.name)
-    output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name)
-    output_cast.set_attr("to", dtype)
+    output_cast = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, to=dtype)
     ctx.set_dtype(output_cast.output[0], dtype)
     ctx.copy_shape(node.output[0], output_cast.output[0])

@@ -1556,8 +1554,7 @@ def version_8(cls, ctx, node, **kwargs):

     seq_len_dtype = ctx.get_dtype(node.input[1])
     if seq_len_dtype != onnx_pb.TensorProto.INT64:
-        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1])
-        cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
+        cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
         ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.INT64)
         ctx.copy_shape(node.input[1], cast_node.output[0])

@@ -1763,8 +1760,8 @@ def version_11(cls, ctx, node, **kwargs):
     # cast to int64 if needed
     if dtypes[1] != onnx_pb.TensorProto.UINT64:
         cast_node = ctx.insert_new_node_on_output("Cast", node.output[1],
-                                                   name=utils.make_name(node.name) + "_cast")
-        cast_node.set_attr("to", dtypes[1])
+                                                   name=utils.make_name(node.name) + "_cast",
+                                                   to=dtypes[1])
         ctx.set_dtype(cast_node.output[0], dtypes[1])
         ctx.copy_shape(node.output[1], cast_node.output[0])
     # FIXME: the indices in onnx are not the same as in tensorflow.

tf2onnx/tfonnx.py

Lines changed: 4 additions & 4 deletions
@@ -160,8 +160,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
         input_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
         g.set_dtype(input_name, onnx_pb.TensorProto.FLOAT)
     else:
-        cast_node = g.insert_new_node_on_input(op, "Cast", input_name)
-        cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
+        cast_node = g.insert_new_node_on_input(op, "Cast", input_name,
+                                               to=onnx_pb.TensorProto.FLOAT)
         g.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
         g.copy_shape(input_name, cast_node.output[0])
         cast_inserted.append(cast_node)

@@ -171,8 +171,8 @@ def rewrite_incomplete_type_support(g, ops, impacted_ops):
     name = utils.make_name(op.name)
     logger.debug("insert cast back for node %s on output %s [dtype=%s]", op.name, output_name,
                  output_dtype)
-    output_cast = g.insert_new_node_on_output("Cast", output_name, name=name)
-    output_cast.set_attr("to", output_dtype)
+    output_cast = g.insert_new_node_on_output("Cast", output_name, name=name,
+                                              to=output_dtype)
     g.set_dtype(output_cast.output[0], output_dtype)
     g.copy_shape(output_name, output_cast.output[0])
     cast_inserted.append(output_cast)

0 commit comments
