Skip to content

Commit 7adddb1

Browse files
Improve messaging for RandomUniform
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 2b9860a commit 7adddb1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tf2onnx/onnx_opset/generator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def version_1(cls, ctx, node, **kwargs):
3838
# const we make it an attribute.
3939
seed = node.get_attr("seed")
4040
node.set_attr("seed", float(seed.f))
41-
if len(node.input) > 0 and node.inputs[0].is_const():
42-
shape = node.inputs[0].get_tensor_value()
43-
ctx.remove_input(node, node.input[0], 0)
44-
node.set_attr("shape", shape)
45-
ctx.set_shape(node.output[0], shape)
41+
utils.make_sure(node.inputs[0].is_const(), "RandomUniform with non-const shape requires opset >= 9.")
42+
shape = node.inputs[0].get_tensor_value()
43+
ctx.remove_input(node, node.input[0], 0)
44+
node.set_attr("shape", shape)
45+
ctx.set_shape(node.output[0], shape)
4646

4747
@classmethod
4848
def version_9(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)