@@ -164,12 +164,12 @@ def get_conv_getdata(kind=1):
164
164
165
165
def get_maxpoolwithargmax_getdata ():
166
166
data = [
167
- ('SAME' , [1 , 3 , 3 , 1 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
168
- ('SAME' , [1 , 5 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
169
- ('SAME' , [1 , 10 , 5 , 1 ], [1 , 2 , 2 , 1 ], [1 , 2 , 2 , 1 ]),
170
- ('SAME' , [1 , 10 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 1 , 1 , 1 ]),
171
- ('VALID' , [1 , 3 , 3 , 1 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
172
- ('VALID' , [1 , 5 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
167
+ ('SAME' , [1 , 3 , 3 , 2 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
168
+ ('SAME' , [2 , 5 , 5 , 3 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
169
+ ('SAME' , [2 , 10 , 5 , 1 ], [1 , 2 , 2 , 1 ], [1 , 2 , 2 , 1 ]),
170
+ ('SAME' , [2 , 10 , 5 , 3 ], [1 , 4 , 4 , 1 ], [1 , 1 , 1 , 1 ]),
171
+ ('VALID' , [2 , 3 , 3 , 3 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
172
+ ('VALID' , [2 , 5 , 5 , 3 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
173
173
]
174
174
for idx , v in enumerate (data ):
175
175
yield (idx ,) + v
@@ -3738,13 +3738,41 @@ def func(x):
3738
3738
def test_maxpoolwithargmax (self ):
3739
3739
for p in get_maxpoolwithargmax_getdata ():
3740
3740
_ , padding , x_shape , ksize , strides = p
3741
- x_val = make_xval ( x_shape )
3741
+ x_val = np . random . uniform ( 0 , 10 , x_shape )
3742
3742
def func (x ):
3743
3743
mp = tf .nn .max_pool_with_argmax (x , ksize , strides , padding = padding )
3744
3744
return tf .identity (mp [0 ], name = _TFOUTPUT ), tf .identity (mp [1 ], name = _TFOUTPUT1 )
3745
3745
self .logger .debug (str (p ))
3746
3746
self ._run_test_case (func , [_OUTPUT , _OUTPUT1 ], {_INPUT : x_val })
3747
3747
3748
+ @check_tf_min_version ("1.13" )
3749
+ @check_opset_min_version (11 , "MaxPoolWithArgmax" )
3750
+ def test_maxpoolwithargmax_batch_in_index (self ):
3751
+ padding = 'SAME'
3752
+ x_shape = [2 , 10 , 5 , 3 ]
3753
+ ksize = [1 , 4 , 4 , 1 ]
3754
+ strides = [1 , 1 , 1 , 1 ]
3755
+ x_val = np .random .uniform (0 , 10 , x_shape )
3756
+ def func (x ):
3757
+ mp = tf .nn .max_pool_with_argmax (x , ksize , strides , padding = padding , include_batch_in_index = True )
3758
+ return tf .identity (mp [0 ], name = _TFOUTPUT ), tf .identity (mp [1 ], name = _TFOUTPUT1 )
3759
+ self ._run_test_case (func , [_OUTPUT , _OUTPUT1 ], {_INPUT : x_val })
3760
+
3761
+ @check_tf_min_version ("1.13" )
3762
+ @check_opset_min_version (11 , "MaxPoolWithArgmax" )
3763
+ def test_maxpoolwithargmax_unknown_c (self ):
3764
+ padding = 'SAME'
3765
+ x_shape = [2 , 10 , 5 , 1 ]
3766
+ ksize = [1 , 4 , 4 , 1 ]
3767
+ strides = [1 , 1 , 1 , 1 ]
3768
+ x_val = np .random .uniform (0 , 10 , x_shape )
3769
+ s_val = np .array ([2 , 10 , 5 , 4 ], np .int64 )
3770
+ def func (x , s ):
3771
+ x = tf .broadcast_to (x , s )
3772
+ mp = tf .nn .max_pool_with_argmax (x , ksize , strides , padding = padding , include_batch_in_index = True )
3773
+ return tf .identity (mp [0 ], name = _TFOUTPUT ), tf .identity (mp [1 ], name = _TFOUTPUT1 )
3774
+ self ._run_test_case (func , [_OUTPUT , _OUTPUT1 ], {_INPUT : x_val , _INPUT1 : s_val })
3775
+
3748
3776
@check_opset_min_version (10 , "Selu" )
3749
3777
def test_selu (self ):
3750
3778
x_val = np .random .random_sample ([3 ]).astype (np .float32 )
0 commit comments