@@ -1800,16 +1800,61 @@ def func():
1800
1800
# since results are random, compare the shapes only
1801
1801
self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
1802
1802
1803
- @unittest .skip ("TF RandomUniformInt is not supported" )
1804
1803
def test_randomuniform_int (self ):
1805
1804
def func ():
1806
- shape = tf .constant ([2 , 3 ], name = "shape" )
1807
- x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , maxval = 10 )
1805
+ shape = tf .constant ([100 , 3 ], name = "shape" )
1806
+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = 2 , maxval = 10 )
1808
1807
x_ = tf .identity (x_ , name = "output1" )
1809
1808
x_ = tf .identity (x_ , name = "output2" )
1810
1809
return tf .identity (x_ , name = _TFOUTPUT )
1811
1810
# since results are random, compare the shapes only
1812
- self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
1811
+ g = self ._run_test_case (func , [_OUTPUT ], {}, check_value = False , check_shape = True )
1812
+ results = self .run_backend (g , [_OUTPUT ], {})
1813
+ numbers = set (results [0 ].flatten ())
1814
+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
1815
+
1816
+ def test_randomuniform_int_nonconst_max (self ):
1817
+ m_val = np .array (8 , dtype = np .int32 )
1818
+ def func (m ):
1819
+ shape = tf .constant ([100 , 3 ], name = "shape" )
1820
+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = 0 , maxval = m )
1821
+ x_ = tf .identity (x_ , name = "output1" )
1822
+ x_ = tf .identity (x_ , name = "output2" )
1823
+ return tf .identity (x_ , name = _TFOUTPUT )
1824
+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : m_val }, check_value = False , check_shape = True )
1825
+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : m_val })
1826
+ numbers = set (results [0 ].flatten ())
1827
+ self .assertEqual (sorted (numbers ), list (range (8 )))
1828
+
1829
+ def test_randomuniform_int_nonconst_min_max (self ):
1830
+ n_val = np .array (2 , dtype = np .int32 )
1831
+ m_val = np .array (10 , dtype = np .int32 )
1832
+ def func (n , m ):
1833
+ shape = tf .constant ([100 , 3 ], name = "shape" )
1834
+ x_ = random_uniform (shape , name = "rand" , dtype = tf .int32 , minval = n , maxval = m )
1835
+ x_ = tf .identity (x_ , name = "output1" )
1836
+ x_ = tf .identity (x_ , name = "output2" )
1837
+ return tf .identity (x_ , name = _TFOUTPUT )
1838
+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val }, check_value = False , check_shape = True )
1839
+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val })
1840
+ numbers = set (results [0 ].flatten ())
1841
+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
1842
+
1843
+ @check_opset_min_version (9 , "RandomUniformLike" )
1844
+ def test_randomuniform_int_nonconst_min_max_shape (self ):
1845
+ n_val = np .array (2 , dtype = np .int32 )
1846
+ m_val = np .array (10 , dtype = np .int32 )
1847
+ s_val = np .array ([100 , 3 ], dtype = np .int64 )
1848
+ def func (n , m , s ):
1849
+ x_ = random_uniform (s , name = "rand" , dtype = tf .int32 , minval = n , maxval = m )
1850
+ x_ = tf .identity (x_ , name = "output1" )
1851
+ x_ = tf .identity (x_ , name = "output2" )
1852
+ return tf .identity (x_ , name = _TFOUTPUT )
1853
+ g = self ._run_test_case (func , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val , _INPUT2 : s_val },
1854
+ check_value = False , check_shape = True )
1855
+ results = self .run_backend (g , [_OUTPUT ], {_INPUT : n_val , _INPUT1 : m_val , _INPUT2 : s_val })
1856
+ numbers = set (results [0 ].flatten ())
1857
+ self .assertEqual (sorted (numbers ), list (range (2 , 10 )))
1813
1858
1814
1859
@skip_caffe2_backend ()
1815
1860
@check_opset_after_tf_version ("2.2" , 9 , "RandomUniform" )
@@ -2981,7 +3026,7 @@ def func(input_x):
2981
3026
2982
3027
@check_opset_min_version (11 , "CumSum" )
2983
3028
def test_matrix_band_part_3 (self ):
2984
- for low , high in [(- 1 , 3 ), (2 , 3 ), (4 , 3 ), (0 , - 1 ), (0 , 0 )]:
3029
+ for low , high in [(- 1 , 3 ), (2 , 3 ), (4 , 3 ), (0 , - 1 ), (0 , 0 ), ( - 1 , - 1 ) ]:
2985
3030
input_val = np .random .randint (0 , 666 , (10 , 15 )).astype (np .int32 )
2986
3031
def func (input_x ):
2987
3032
res = tf .linalg .band_part (input_x , low , high )
0 commit comments