Skip to content

Commit 2093b38

Browse files
Implement RandomStandardNormal
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 4bde219 commit 2093b38

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

tests/test_backend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,45 @@ def func():
19621962
# since results are random, compare the shapes only
19631963
self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
19641964

1965+
def test_random_std_normal(self):
1966+
def func():
1967+
shape = tf.constant([20, 10, 5], name="shape")
1968+
x_ = tf.random.normal(shape)
1969+
return tf.identity(x_, name=_TFOUTPUT)
1970+
# since results are random, compare the shapes only
1971+
g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
1972+
results = self.run_backend(g, g.outputs, {})[0]
1973+
self.assertTrue(-0.1 < np.mean(results) < 0.1)
1974+
self.assertTrue(0.9 < np.std(results) < 1.1)
1975+
1976+
def test_randomnormal(self):
1977+
def func():
1978+
shape = tf.constant([20, 10, 5], name="shape")
1979+
x_ = tf.random.normal(shape, mean=10, stddev=2)
1980+
return tf.identity(x_, name=_TFOUTPUT)
1981+
# since results are random, compare the shapes only
1982+
g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
1983+
results = self.run_backend(g, g.outputs, {})[0]
1984+
if not 9.9 < np.mean(results) < 10.1:
1985+
np.testing.assert_allclose(np.mean(results), 0)
1986+
self.assertTrue(1.9 < np.std(results) < 2.1)
1987+
1988+
@check_opset_min_version(9, "RandomNormalLike")
1989+
def test_randomnormal_unknown_shape(self):
1990+
shape_val = np.array([20, 10, 5], np.int32)
1991+
def func(shape):
1992+
x_ = tf.random.normal(shape)
1993+
return tf.identity(x_, name=_TFOUTPUT)
1994+
# since results are random, compare the shapes only
1995+
feed_dict = {_INPUT: shape_val}
1996+
g = self._run_test_case(func, [_OUTPUT], feed_dict, check_value=False, check_shape=True)
1997+
if "input" in g.input_names:
1998+
# TFLite inputs don't have port numbers
1999+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
2000+
results = self.run_backend(g, g.outputs, feed_dict)[0]
2001+
self.assertTrue(-0.1 < np.mean(results) < 0.1)
2002+
self.assertTrue(0.9 < np.std(results) < 1.1)
2003+
19652004
def test_randomuniform_int(self):
19662005
def func():
19672006
shape = tf.constant([100, 3], name="shape")

tf2onnx/onnx_opset/generator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def version_1(cls, ctx, node, **kwargs):
2929
pass
3030

3131

32-
@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt"])
32+
@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt", "RandomStandardNormal"])
3333
class RandomOp:
3434
@classmethod
3535
def randuniform_int(cls, ctx, rand_node, rand_out, min_inp, max_inp):
@@ -66,7 +66,7 @@ def version_1(cls, ctx, node, **kwargs):
6666
# const we make it an attribute.
6767
seed = node.get_attr("seed")
6868
node.set_attr("seed", float(seed.f))
69-
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9")
69+
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9", node.type)
7070
shape = node.inputs[0].get_tensor_value()
7171
ctx.remove_input(node, node.input[0], 0)
7272
if len(shape) == 0:
@@ -84,6 +84,8 @@ def version_1(cls, ctx, node, **kwargs):
8484
cls.randuniform_int(ctx, node, rand_out, node.input[0], node.input[1])
8585
node.type = "RandomUniform"
8686
ctx.replace_inputs(node, [])
87+
elif node.type == "RandomStandardNormal":
88+
node.type = "RandomNormal"
8789

8890
@classmethod
8991
def version_9(cls, ctx, node, **kwargs):
@@ -99,6 +101,8 @@ def version_9(cls, ctx, node, **kwargs):
99101
if node.type == "RandomUniformInt":
100102
cls.randuniform_int(ctx, node, node.output[0], inputs[1], inputs[2])
101103
node.type = "RandomUniformLike"
104+
elif node.type == "RandomStandardNormal":
105+
node.type = "RandomNormalLike"
102106
else:
103107
node.type = node.type + 'Like'
104108

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,29 @@ def rewrite_random_normal(g, ops):
3939
else:
4040
# pattern 2
4141
mean = 0.0
42+
input2 = match.get_op('input2')
43+
if input2.type == 'Mul':
44+
scale = input2.inputs[1].get_tensor_value()
45+
else:
46+
scale = 1.0
4247
dtype = g.get_dtype(output.output[0])
4348
op_name = utils.make_name("RandomNormal")
4449
out_name = utils.port_name(op_name)
4550

4651
rn_op = match.get_op('input1')
4752
seed = float(rn_op.get_attr('seed2').i)
4853

54+
attr = {"mean": mean, "scale": scale, "dtype": dtype, "seed": seed}
4955
if rn_op.inputs[0].type == "Shape":
5056
shape_node = rn_op.inputs[0]
5157
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
52-
attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
58+
attr=attr)
5359
else:
5460
shape = g.get_shape(output.output[0])
5561
if shape is None or -1 in shape:
5662
continue
57-
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
58-
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
63+
attr['shape'] = shape
64+
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, attr=attr)
5965

6066
g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
6167
g.safe_remove_nodes(match.get_nodes())

0 commit comments

Comments
 (0)