
Commit 3128233
Merge branch 'master' into tom/MatrixBandPart
2 parents: 8dc9e2b + 5d2b73c

4 files changed: 95 additions, 14 deletions

tests/test_backend.py

Lines changed: 50 additions & 5 deletions
@@ -1800,16 +1800,61 @@ def func():
         # since results are random, compare the shapes only
         self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
 
-    @unittest.skip("TF RandomUniformInt is not supported")
     def test_randomuniform_int(self):
         def func():
-            shape = tf.constant([2, 3], name="shape")
-            x_ = random_uniform(shape, name="rand", dtype=tf.int32, maxval=10)
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=2, maxval=10)
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             return tf.identity(x_, name=_TFOUTPUT)
         # since results are random, compare the shapes only
-        self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
+        g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
+
+    def test_randomuniform_int_nonconst_max(self):
+        m_val = np.array(8, dtype=np.int32)
+        def func(m):
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=0, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: m_val}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: m_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(8)))
+
+    def test_randomuniform_int_nonconst_min_max(self):
+        n_val = np.array(2, dtype=np.int32)
+        m_val = np.array(10, dtype=np.int32)
+        def func(n, m):
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=n, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
+
+    @check_opset_min_version(9, "RandomUniformLike")
+    def test_randomuniform_int_nonconst_min_max_shape(self):
+        n_val = np.array(2, dtype=np.int32)
+        m_val = np.array(10, dtype=np.int32)
+        s_val = np.array([100, 3], dtype=np.int64)
+        def func(n, m, s):
+            x_ = random_uniform(s, name="rand", dtype=tf.int32, minval=n, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val},
+                                check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
 
     @skip_caffe2_backend()
     @check_opset_after_tf_version("2.2", 9, "RandomUniform")
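The reworked tests no longer check shapes only: they execute the converted graph and assert that the set of produced integers equals the full expected range. With a 100x3 output (300 draws) over at most eight distinct values, the chance that some value in the range never appears is roughly 8*(7/8)^300, on the order of 1e-17, so the set-equality assertion is effectively deterministic while still catching off-by-one errors in the minval/maxval handling.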
@@ -2981,7 +3026,7 @@ def func(input_x):
 
     @check_opset_min_version(11, "CumSum")
     def test_matrix_band_part_3(self):
-        for low, high in [(-1, 3), (2, 3), (4, 3), (0, -1), (0, 0)]:
+        for low, high in [(-1, 3), (2, 3), (4, 3), (0, -1), (0, 0), (-1, -1)]:
             input_val = np.random.randint(0, 666, (10, 15)).astype(np.int32)
             def func(input_x):
                 res = tf.linalg.band_part(input_x, low, high)
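The added (-1, -1) case asks tf.linalg.band_part to keep the entire lower and upper triangles, so the op degenerates to an identity; it exercises the len(conditions) == 0 branch of MatrixBandPart.version_11, visible in the nn.py hunk below, where the node is rewritten to an ONNX Identity.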

tf2onnx/onnx_opset/generator.py

Lines changed: 37 additions & 3 deletions
@@ -28,8 +28,33 @@ def version_1(cls, ctx, node, **kwargs):
         pass
 
 
-@tf_op(["RandomNormal", "RandomUniform"])
+@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt"])
 class RandomOp:
+    @classmethod
+    def randuniform_int(cls, ctx, node, min_inp, max_inp):
+        dtype = ctx.get_dtype(node.output[0])
+        min_node = ctx.get_node_by_output(min_inp)
+        max_node = ctx.get_node_by_output(max_inp)
+        if min_node.is_const() and max_node.is_const():
+            node.set_attr('low', float(min_node.get_tensor_value()))
+            node.set_attr('high', float(max_node.get_tensor_value()))
+            out = node.output[0]
+        elif min_node.is_const() and min_node.get_tensor_value() == 0:
+            max_float = ctx.make_node("Cast", [max_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], inputs=[node.output[0], max_float])
+            out = mul_node.output[0]
+        else:
+            min_float = ctx.make_node("Cast", [min_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            max_float = ctx.make_node("Cast", [max_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            diff = ctx.make_node("Sub", [max_float, min_float]).output[0]
+            diff_float = ctx.make_node("Cast", [diff], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], inputs=[node.output[0], diff_float])
+            mul = mul_node.output[0]
+            add_node = ctx.insert_new_node_on_output("Add", mul, inputs=[mul, min_float])
+            out = add_node.output[0]
+        floor_node = ctx.insert_new_node_on_output("Floor", out)
+        ctx.insert_new_node_on_output("Cast", floor_node.output[0], to=dtype)
 
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
         # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
@@ -43,6 +68,10 @@ def version_1(cls, ctx, node, **kwargs):
             ctx.remove_input(node, node.input[0], 0)
             node.set_attr("shape", shape)
             ctx.set_shape(node.output[0], shape)
+            if node.type == "RandomUniformInt":
+                cls.randuniform_int(ctx, node, node.input[0], node.input[1])
+                node.type = "RandomUniform"
+                ctx.replace_inputs(node, [])
 
     @classmethod
     def version_9(cls, ctx, node, **kwargs):
@@ -51,10 +80,15 @@ def version_9(cls, ctx, node, **kwargs):
         else:
             seed = node.get_attr("seed")
             node.set_attr("seed", float(seed.f))
-        cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
+        cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': onnx_pb.TensorProto.INT64})
         const_node = ctx.make_node("ConstantOfShape", cast_node.output)
+        inputs = node.input.copy()
         ctx.replace_inputs(node, const_node.output.copy())
-        node.type = node.type + 'Like'
+        if node.type == "RandomUniformInt":
+            cls.randuniform_int(ctx, node, inputs[1], inputs[2])
+            node.type = "RandomUniformLike"
+        else:
+            node.type = node.type + 'Like'
 
 
 @tf_op(["RandomNormalLike", "RandomUniformLike"])

tf2onnx/onnx_opset/nn.py

Lines changed: 7 additions & 6 deletions
@@ -1188,7 +1188,7 @@ def version_11(cls, ctx, node, **kwargs):
 @tf_op("MatrixBandPart")
 class MatrixBandPart:
     @classmethod
-    def version_7(cls, opset, ctx, node, **kwargs):
+    def version_7(cls, ctx, node, **kwargs):
         # T output = MatrixBandPart(T input, int num_lower, int num_upper)
         # data-flow: first generate mask matrix and then use element-wise mul op
         input_rank = len(ctx.get_shape(node.input[0]))
@@ -1274,7 +1274,6 @@ def version_11(cls, ctx, node, **kwargs):
         rank = ctx.get_rank(data)
         int_max_val = utils.get_max_value(np.int64)
         dtype = ctx.get_dtype(data)
-        np_dtype = utils.map_onnx_to_numpy_type(dtype)
         if rank == 2:
             shape = ctx.make_node("Shape", [data]).output[0]
         else:
@@ -1288,8 +1287,6 @@ def version_11(cls, ctx, node, **kwargs):
         zero_tensor = helper.make_tensor("value", dtype, dims=[1], vals=[0])
         const_of_shape = ctx.make_node("ConstantOfShape", [shape], attr={'value': zero_tensor}).output[0]
         identity_node = ctx.make_node("EyeLike", [const_of_shape]).output[0]
-        one_const = ctx.make_const(utils.make_name("one"), np.array(1, np_dtype)).output[0]
-        mask = ctx.make_node("Sub", [one_const, identity_node]).output[0]
         shapes = node.output_shapes
         dtypes = node.output_dtypes
         ctx.remove_node(node.name)
@@ -1316,12 +1313,16 @@ def version_11(cls, ctx, node, **kwargs):
         if num_upper_const is None or num_upper_const >= 0:
             if ctx.get_dtype(num_upper) != TensorProto.INT64:
                 num_upper = ctx.make_node("Cast", [num_upper], attr={'to': TensorProto.INT64}).output[0]
-            conditions.append(ctx.make_node("LessOrEqual", [idx_diff, num_upper]).output[0])
+            greater = ctx.make_node("Greater", [idx_diff, num_upper]).output[0]
+            less_or_equal = ctx.make_node("Not", [greater]).output[0]
+            conditions.append(less_or_equal)
         if num_lower_const is None or num_lower_const >= 0:
             if ctx.get_dtype(num_lower) != TensorProto.INT64:
                 num_lower = ctx.make_node("Cast", [num_lower], attr={'to': TensorProto.INT64}).output[0]
             num_lower_neg = ctx.make_node("Neg", [num_lower]).output[0]
-            conditions.append(ctx.make_node("LessOrEqual", [num_lower_neg, idx_diff]).output[0])
+            greater = ctx.make_node("Greater", [num_lower_neg, idx_diff]).output[0]
+            less_or_equal = ctx.make_node("Not", [greater]).output[0]
+            conditions.append(less_or_equal)
         if len(conditions) == 0:
             node.type = "Identity"
             ctx.replace_inputs(node, [data])
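The replacement of LessOrEqual with Greater followed by Not is behavior-preserving: a <= b is exactly not (a > b) elementwise. Presumably the motivation is opset coverage, since LessOrEqual only entered ONNX at opset 12 while Greater and Not are available earlier, so version_11 no longer emits an op newer than the opset it targets. A quick NumPy sanity check of the identity (example values are mine):

    import numpy as np

    a = np.array([-2, 0, 3])   # arbitrary example values
    b = np.array([1, 0, 1])
    # Not(Greater(a, b)) reproduces LessOrEqual(a, b) elementwise
    assert np.array_equal(~(a > b), a <= b)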

tf2onnx/tf_utils.py

Lines changed: 1 addition & 0 deletions
@@ -221,6 +221,7 @@ def is_huge_shape(x):
                 outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                 progress = True
             can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
+            can_fold = can_fold and not node.type.startswith('Random')
             can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
             # We can only fold nodes with a single output
             can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
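This one-line guard complements the generator change: node types starting with 'Random' are now excluded from constant folding, because folding a random op whose inputs are constant would bake a single fixed sample into the converted model, and every inference would then return the same "random" tensor.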
