Skip to content

Commit 0fd906d

Browse files
Fix random uniform rewriter for scalar shapes (#1429)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 97d38f1 commit 0fd906d

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tf2onnx/rewriter/random_uniform.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
import numpy as np
88
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
from tf2onnx.graph_builder import GraphBuilder
910
from tf2onnx import utils, handler
1011

1112

@@ -74,9 +75,15 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
7475
shape = g.get_shape(output.output[0])
7576
if shape_node.is_const():
7677
# if the tensorflow input (aka the shape) is const we can use the RandomUniform op
78+
needs_squeeze = False
79+
if len(shape) == 0:
80+
shape = [1]
81+
needs_squeeze = True
7782
new_node = g.make_node("RandomUniform", [], name=op_name,
7883
attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
7984
shapes=[shape], dtypes=[dtype])
85+
if needs_squeeze:
86+
new_node = GraphBuilder(g).make_squeeze({"data": new_node.output[0], "axes": [0]}, return_node=True)
8087
else:
8188
if shape_node.type == "Shape":
8289
# if shape is dynamic - in tensorflow shape comes as tensor VALUE,

0 commit comments

Comments
 (0)