From 0338e1ebe7e931e80badb74fd69eaff880d058e8 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Mon, 26 Oct 2020 19:10:12 -0400 Subject: [PATCH] Fixed half_pixel_centers for resize_nearest_neighbor Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 25 +++++++++++++++++++++++++ tf2onnx/onnx_opset/nn.py | 5 ++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 449741f8e..a7f3f89dd 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2101,6 +2101,19 @@ def func(x): return tf.identity(x_, name=_TFOUTPUT) _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @skip_caffe2_backend() + @check_tf_min_version("1.14") + @check_opset_min_version(11, "coordinate_transformation_mode attr") + def test_resize_bilinear_half_pixel_centers(self): + x_shape = [1, 15, 20, 2] + x_new_size = [30, 40] + x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape) + def func(x): + x_new_size_ = tf.constant(x_new_size) + x_ = resize_bilinear(x, x_new_size_, half_pixel_centers=True) + return tf.identity(x_, name=_TFOUTPUT) + _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(9, "resize_bilinear") def test_resize_bilinear_with_non_const(self): x_shape = [3, 10, 8, 5] @@ -2144,6 +2157,18 @@ def func(x): return tf.identity(x_, name=_TFOUTPUT) _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_tf_min_version("1.14") + @check_opset_min_version(11, "coordinate_transformation_mode attr") + def test_resize_nearest_neighbor_half_pixel_centers(self): + x_shape = [1, 10, 20, 2] + x_new_size = [20, 40] + x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape) + def func(x): + x_new_size_ = tf.constant(x_new_size) + x_ = resize_nearest_neighbor(x, x_new_size_, half_pixel_centers=True) + return tf.identity(x_, name=_TFOUTPUT) + _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(9, "fill") def test_fill_float32(self): x_shape = [1, 15, 20, 2] diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 8e6c1f827..546313009 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -903,7 +903,10 @@ def version_11(cls, ctx, node, **kwargs): if "align_corners" in node.attr and node.attr["align_corners"].i: transformation_mode = "align_corners" if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i: - transformation_mode = "half_pixel" + if node.type == "ResizeNearestNeighbor": + transformation_mode = "tf_half_pixel_for_nn" + else: + transformation_mode = "half_pixel" resize = ctx.make_node("Resize", resize_inputs, attr={"mode": mode, "nearest_mode": "floor", "coordinate_transformation_mode": transformation_mode})