From 1a5fc521c642638d1ad582533ec2c2484191c0d0 Mon Sep 17 00:00:00 2001 From: charlieguo0307 <869425298@qq.com> Date: Fri, 14 Aug 2020 15:36:25 +0800 Subject: [PATCH] fix NCHW converting bug when in pad_rewriter --- tf2onnx/rewriter/conv2d_with_pad_rewriter.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tf2onnx/rewriter/conv2d_with_pad_rewriter.py b/tf2onnx/rewriter/conv2d_with_pad_rewriter.py index a1e72eb56..ad4a8f49c 100644 --- a/tf2onnx/rewriter/conv2d_with_pad_rewriter.py +++ b/tf2onnx/rewriter/conv2d_with_pad_rewriter.py @@ -43,10 +43,16 @@ def rewrite_conv2d_with_pad(g, ops): logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name) paddings_val = np.array(paddings.get_tensor_value()) # can't pad on batch or channel dimensions - if np.any(paddings_val[0]) or np.any(paddings_val[3]): - continue + data_format = conv.get_attr("data_format").s.decode("utf-8") + if data_format == "NHWC": + if np.any(paddings_val[0]) or np.any(paddings_val[3]): + continue + paddings_val = paddings_val[1:3] + else: + if np.any(paddings_val[0]) or np.any(paddings_val[1]): + continue + paddings_val = paddings_val[2:4] - paddings_val = paddings_val[1:3] paddings_val = paddings_val.transpose().flatten() g.replace_input(conv, conv.input[0], pad.input[0]) # convert Conv2D