Skip to content
8 changes: 4 additions & 4 deletions tf2onnx/onnx_opset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def version_1(cls, ctx, node, **kwargs):
# in the rewriter does not trigger. grappler will send the random uniform
# with shape as input so we need to pickup the input here and if the shape is
# const we make it an attribute.
seed = node.get_attr("seed")
node.set_attr("seed", float(seed.f))
seed = node.get_attr("seed2")
node.set_attr("seed", float(seed.i))
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9", node.type)
shape = node.inputs[0].get_tensor_value()
ctx.remove_input(node, node.input[0], 0)
Expand All @@ -88,8 +88,8 @@ def version_9(cls, ctx, node, **kwargs):
if node.inputs[0].is_const():
cls.version_1(ctx, node, **kwargs)
else:
seed = node.get_attr("seed")
node.set_attr("seed", float(seed.f))
seed = node.get_attr("seed2")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to also update this for version_1(line 63)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @fatcat-z you're quite right, ill update version_1 as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused: is this behavior (using seed2 instead of seed) same among different TF versions? Did it change after one of tf version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I've seen and tested, version_1 behaviour is also relevant for TF version.2
when the random generator dim/shape are constant and not derived from the data
batch size or any other dependent size.

I'm checking with TF V1 whether it was changed.

node.set_attr("seed", float(seed.i))
cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': onnx_pb.TensorProto.INT64})
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
inputs = node.input.copy()
Expand Down