Skip to content

Commit 7768115

Browse files
committed
Fix TF1-Keras Dilated Conv Export
Update patch; skip CPU test for Conv3D NCDHW; skip CPU test for Conv2D NCHW. Signed-off-by: Lei Mao <[email protected]>
1 parent 4245d8d commit 7768115

File tree

2 files changed

+113
-13
lines changed

2 files changed

+113
-13
lines changed

tests/test_backend.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
317

418
"""Unit tests using onnx backends."""
519

@@ -727,7 +741,7 @@ def func(x):
727741
onnx_feed_dict={_INPUT: x_val_for_onnx})
728742

729743
@skip_tflite("TFlite adds ops that obscure pattern")
730-
@check_tf_min_version("2.0")
744+
@check_tf_min_version("1.15")
731745
def test_conv1d_dilations_rewriter(self):
732746
x_shape = [2, 32, 3]
733747
x_val = make_xval(x_shape)
@@ -740,7 +754,7 @@ def func(x):
740754
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
741755
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
742756

743-
@check_tf_min_version("2.0")
757+
@check_tf_min_version("1.15")
744758
def test_conv2d_dilations_rewriter(self):
745759
x_shape = [2, 32, 16, 3]
746760
x_val = make_xval(x_shape)
@@ -760,7 +774,39 @@ def func(x):
760774
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
761775
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
762776

763-
@check_tf_min_version("2.0")
777+
@check_tf_min_version("1.15")
778+
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
779+
def test_nchw_conv2d_dilations_rewriter(self):
780+
x_shape = [2, 3, 32, 16]
781+
x_val = make_xval(x_shape)
782+
for p in ['SAME', 'VALID']:
783+
def func(x):
784+
t = tf.keras.layers.Conv2D(
785+
filters=768,
786+
kernel_size=3,
787+
dilation_rate=3,
788+
padding=p,
789+
data_format='channels_first'
790+
)
791+
t.build(x_shape)
792+
y = t.call(x)
793+
return tf.identity(y, name=_TFOUTPUT)
794+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
795+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
796+
def func(x):
797+
t = tf.keras.layers.DepthwiseConv2D(
798+
kernel_size=3,
799+
dilation_rate=3,
800+
padding=p,
801+
data_format='channels_first'
802+
)
803+
t.build(x_shape)
804+
y = t.call(x)
805+
return tf.identity(y, name=_TFOUTPUT)
806+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
807+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
808+
809+
@check_tf_min_version("1.15")
764810
@skip_tflite("TFlite adds ops that obscure pattern")
765811
@allow_missing_shapes("Rewriting makes some shapes known")
766812
def test_conv2d_dilations_rewriter_unknown_shape(self):
@@ -776,7 +822,30 @@ def func():
776822
as_session=True, premade_placeholders=True,
777823
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
778824

779-
@check_tf_min_version("2.0")
825+
@check_tf_min_version("1.15")
826+
@skip_tflite("TFlite adds ops that obscure pattern")
827+
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
828+
@allow_missing_shapes("Rewriting makes some shapes known")
829+
def test_nchw_conv2d_dilations_rewriter_unknown_shape(self):
830+
x_shape = [2, 3, 32, 16]
831+
x_val = make_xval(x_shape)
832+
def func():
833+
x = tf_placeholder(tf.float32, [2, 3, None, None], name=_TFINPUT)
834+
t = tf.keras.layers.Conv2D(
835+
filters=768,
836+
kernel_size=3,
837+
dilation_rate=3,
838+
padding="VALID",
839+
data_format='channels_first'
840+
)
841+
t.build(x_shape)
842+
y = t.call(x)
843+
return tf.identity(y, name=_TFOUTPUT)
844+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2,
845+
as_session=True, premade_placeholders=True,
846+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
847+
848+
@check_tf_min_version("1.15")
780849
def test_conv3d_dilations_rewriter(self):
781850
x_shape = [2, 32, 16, 8, 3]
782851
x_val = make_xval(x_shape)
@@ -789,6 +858,26 @@ def func(x):
789858
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
790859
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
791860

861+
@check_tf_min_version("1.15")
862+
@skip_tf_cpu("only tf_gpu can run conv3d with NCDHW format")
863+
def test_ncdhw_conv3d_dilations_rewriter(self):
864+
x_shape = [2, 3, 32, 16, 8]
865+
x_val = make_xval(x_shape)
866+
for p in ['SAME', 'VALID']:
867+
def func(x):
868+
t = tf.keras.layers.Conv3D(
869+
filters=768,
870+
kernel_size=3,
871+
dilation_rate=3,
872+
padding=p,
873+
data_format='channels_first'
874+
)
875+
t.build(x_shape)
876+
y = t.call(x)
877+
return tf.identity(y, name=_TFOUTPUT)
878+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
879+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
880+
792881
@skip_tf2("Uses tf.layers")
793882
def test_conv1d_tf1_dilations_rewriter(self):
794883
x_shape = [2, 32, 3]

tf2onnx/rewriter/conv_dilations_rewriter.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
317

418
"""
519
tf2onnx.rewriter.conv_dilations_rewriter - Rewrites the patten used to represent dilations
@@ -13,6 +27,7 @@
1327

1428

1529
def rewrite_conv_dilations(g, ops):
30+
1631
pattern1 = \
1732
OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
1833
OpTypePattern("DepthwiseConv2dNative|Conv2D|Conv3D", name="conv", inputs=[
@@ -67,14 +82,7 @@ def rewrite_conv_dilations(g, ops):
6782
if block_shape1 != block_shape2:
6883
continue
6984
ndims = 2 if is_conv_1d else len(block_shape1)
70-
data_format = b"NHWC" if ndims == 2 else b"NDHWC"
71-
ones = [1] * (ndims + 2)
72-
if conv.get_attr_value("dilations", ones) != ones:
73-
continue
74-
if conv.get_attr_value("strides", ones) != ones:
75-
continue
76-
if conv.get_attr_value("data_format", data_format) != data_format:
77-
continue
85+
7886
if conv.get_attr_value("padding") != b"VALID":
7987
continue
8088

@@ -114,7 +122,10 @@ def rewrite_conv_dilations(g, ops):
114122
g.copy_shape(batch_to_space.output[0], conv.output[0])
115123
g.replace_all_inputs(batch_to_space.output[0], conv.output[0])
116124

117-
conv.set_attr("dilations", [1] + block_shape1 + [1])
125+
if conv.get_attr_value("data_format") in [b"NCHW", b"NCDHW"]:
126+
conv.set_attr("dilations", [1] + block_shape1)
127+
else:
128+
conv.set_attr("dilations", [1] + block_shape1 + [1])
118129
conv.set_attr("padding", pad_mode)
119130
if pad_mode == "EXPLICIT":
120131
conv.set_attr("explicit_paddings", base_pad_flat)

0 commit comments

Comments (0)