Skip to content

Commit 30ec084

Browse files
Fix MaxpoolWithArgmax (#1451)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 57ef758 commit 30ec084

File tree

2 files changed

+71
-7
lines changed

2 files changed

+71
-7
lines changed

tests/test_backend.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,12 @@ def get_conv_getdata(kind=1):
164164

165165
def get_maxpoolwithargmax_getdata():
166166
data = [
167-
('SAME', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
168-
('SAME', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
169-
('SAME', [1, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
170-
('SAME', [1, 10, 5, 1], [1, 4, 4, 1], [1, 1, 1, 1]),
171-
('VALID', [1, 3, 3, 1], [1, 3, 3, 1], [1, 2, 2, 1]),
172-
('VALID', [1, 5, 5, 1], [1, 4, 4, 1], [1, 2, 2, 1]),
167+
('SAME', [1, 3, 3, 2], [1, 3, 3, 1], [1, 2, 2, 1]),
168+
('SAME', [2, 5, 5, 3], [1, 4, 4, 1], [1, 2, 2, 1]),
169+
('SAME', [2, 10, 5, 1], [1, 2, 2, 1], [1, 2, 2, 1]),
170+
('SAME', [2, 10, 5, 3], [1, 4, 4, 1], [1, 1, 1, 1]),
171+
('VALID', [2, 3, 3, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
172+
('VALID', [2, 5, 5, 3], [1, 4, 4, 1], [1, 2, 2, 1]),
173173
]
174174
for idx, v in enumerate(data):
175175
yield (idx,) + v
@@ -3738,13 +3738,41 @@ def func(x):
37383738
def test_maxpoolwithargmax(self):
37393739
for p in get_maxpoolwithargmax_getdata():
37403740
_, padding, x_shape, ksize, strides = p
3741-
x_val = make_xval(x_shape)
3741+
x_val = np.random.uniform(0, 10, x_shape)
37423742
def func(x):
37433743
mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding)
37443744
return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
37453745
self.logger.debug(str(p))
37463746
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})
37473747

3748+
@check_tf_min_version("1.13")
3749+
@check_opset_min_version(11, "MaxPoolWithArgmax")
3750+
def test_maxpoolwithargmax_batch_in_index(self):
3751+
padding = 'SAME'
3752+
x_shape = [2, 10, 5, 3]
3753+
ksize = [1, 4, 4, 1]
3754+
strides = [1, 1, 1, 1]
3755+
x_val = np.random.uniform(0, 10, x_shape)
3756+
def func(x):
3757+
mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding, include_batch_in_index=True)
3758+
return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
3759+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})
3760+
3761+
@check_tf_min_version("1.13")
3762+
@check_opset_min_version(11, "MaxPoolWithArgmax")
3763+
def test_maxpoolwithargmax_unknown_c(self):
3764+
padding = 'SAME'
3765+
x_shape = [2, 10, 5, 1]
3766+
ksize = [1, 4, 4, 1]
3767+
strides = [1, 1, 1, 1]
3768+
x_val = np.random.uniform(0, 10, x_shape)
3769+
s_val = np.array([2, 10, 5, 4], np.int64)
3770+
def func(x, s):
3771+
x = tf.broadcast_to(x, s)
3772+
mp = tf.nn.max_pool_with_argmax(x, ksize, strides, padding=padding, include_batch_in_index=True)
3773+
return tf.identity(mp[0], name=_TFOUTPUT), tf.identity(mp[1], name=_TFOUTPUT1)
3774+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: s_val})
3775+
37483776
@check_opset_min_version(10, "Selu")
37493777
def test_selu(self):
37503778
x_val = np.random.random_sample([3]).astype(np.float32)

tf2onnx/onnx_opset/nn.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,42 @@ def version_8(cls, ctx, node, **kwargs):
694694
# The input data_format is NHWC for TF MaxPoolWithArgmax
695695
node.set_attr("data_format", "NHWC")
696696

697+
# Convert indices from NCHW to NHWC format
698+
input_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
699+
input_shape_guess = ctx.get_shape(node.input[0])
700+
n, h, w, c = ctx.make_node("Split", [input_shape], attr={'axis': 0}, output_count=4).output
701+
hw = ctx.make_node("Mul", [h, w]).output[0]
702+
chw = ctx.make_node("Mul", [hw, c]).output[0]
703+
consumers = ctx.find_output_consumers(node.output[1])
704+
if ctx.opset >= 10:
705+
xy = ctx.make_node("Mod", [node.output[1], hw]).output[0]
706+
else:
707+
xy_div = ctx.make_node("Div", [node.output[1], hw]).output[0]
708+
xy_mul = ctx.make_node("Mul", [xy_div, hw]).output[0]
709+
xy = ctx.make_node("Sub", [node.output[1], xy_mul]).output[0]
710+
xy_scale_c = ctx.make_node("Mul", [xy, c]).output[0]
711+
const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
712+
const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
713+
if input_shape_guess is not None and input_shape_guess[3] > 0:
714+
c_range_np = np.arange(input_shape_guess[3], dtype=np.int64)
715+
c_range = ctx.make_const(utils.make_name("c_range"), c_range_np).output[0]
716+
else:
717+
utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with non-const num channels")
718+
c_sq = GraphBuilder(ctx).make_squeeze({'data': c, 'axes': [0]})
719+
c_range = ctx.make_node("Range", [const_zero, c_sq, const_one]).output[0]
720+
xyc = ctx.make_node("Add", [xy_scale_c, c_range]).output[0]
721+
single_batch = input_shape_guess is not None and input_shape_guess[0] == 1
722+
if node.get_attr_value('include_batch_in_index', False) and not single_batch:
723+
utils.make_sure(ctx.opset >= 11, "opset 11 required for MaxPoolWithArgmax with include_batch_in_index")
724+
n_sq = GraphBuilder(ctx).make_squeeze({'data': n, 'axes': [0]})
725+
n_range = ctx.make_node("Range", [const_zero, n_sq, const_one]).output[0]
726+
n_range_unsq = GraphBuilder(ctx).make_unsqueeze({'data': n_range, 'axes': [1, 2, 3]})
727+
n_range_scale = ctx.make_node("Mul", [n_range_unsq, chw]).output[0]
728+
result = ctx.make_node("Add", [xyc, n_range_scale]).output[0]
729+
else:
730+
result = xyc
731+
ctx.replace_all_inputs(node.output[1], result, ops=consumers)
732+
697733
add_padding(ctx, node, kernel_shape, strides)
698734
conv_convert_inputs(ctx, node, with_kernel=False, input_indices=[0], output_indices=[0, 1])
699735

0 commit comments

Comments
 (0)