1
1
# SPDX-License-Identifier: Apache-2.0
2
2
3
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
3
17
4
18
"""Unit tests using onnx backends."""
5
19
@@ -727,7 +741,7 @@ def func(x):
727
741
onnx_feed_dict = {_INPUT : x_val_for_onnx })
728
742
729
743
@skip_tflite ("TFlite adds ops that obscure pattern" )
730
- @check_tf_min_version ("2.0 " )
744
+ @check_tf_min_version ("1.15 " )
731
745
def test_conv1d_dilations_rewriter (self ):
732
746
x_shape = [2 , 32 , 3 ]
733
747
x_val = make_xval (x_shape )
@@ -740,7 +754,7 @@ def func(x):
740
754
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
741
755
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
742
756
743
- @check_tf_min_version ("2.0 " )
757
+ @check_tf_min_version ("1.15 " )
744
758
def test_conv2d_dilations_rewriter (self ):
745
759
x_shape = [2 , 32 , 16 , 3 ]
746
760
x_val = make_xval (x_shape )
@@ -760,7 +774,39 @@ def func(x):
760
774
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
761
775
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
762
776
763
- @check_tf_min_version ("2.0" )
777
+ @check_tf_min_version ("1.15" )
778
+ @skip_tf_cpu ("only tf_gpu can run conv2d with NCHW format" )
779
+ def test_nchw_conv2d_dilations_rewriter (self ):
780
+ x_shape = [2 , 3 , 32 , 16 ]
781
+ x_val = make_xval (x_shape )
782
+ for p in ['SAME' , 'VALID' ]:
783
+ def func (x ):
784
+ t = tf .keras .layers .Conv2D (
785
+ filters = 768 ,
786
+ kernel_size = 3 ,
787
+ dilation_rate = 3 ,
788
+ padding = p ,
789
+ data_format = 'channels_first'
790
+ )
791
+ t .build (x_shape )
792
+ y = t .call (x )
793
+ return tf .identity (y , name = _TFOUTPUT )
794
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
795
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
796
+ def func (x ):
797
+ t = tf .keras .layers .DepthwiseConv2D (
798
+ kernel_size = 3 ,
799
+ dilation_rate = 3 ,
800
+ padding = p ,
801
+ data_format = 'channels_first'
802
+ )
803
+ t .build (x_shape )
804
+ y = t .call (x )
805
+ return tf .identity (y , name = _TFOUTPUT )
806
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
807
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
808
+
809
+ @check_tf_min_version ("1.15" )
764
810
@skip_tflite ("TFlite adds ops that obscure pattern" )
765
811
@allow_missing_shapes ("Rewriting makes some shapes known" )
766
812
def test_conv2d_dilations_rewriter_unknown_shape (self ):
@@ -776,7 +822,30 @@ def func():
776
822
as_session = True , premade_placeholders = True ,
777
823
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
778
824
779
- @check_tf_min_version ("2.0" )
825
+ @check_tf_min_version ("1.15" )
826
+ @skip_tflite ("TFlite adds ops that obscure pattern" )
827
+ @skip_tf_cpu ("only tf_gpu can run conv2d with NCHW format" )
828
+ @allow_missing_shapes ("Rewriting makes some shapes known" )
829
+ def test_nchw_conv2d_dilations_rewriter_unknown_shape (self ):
830
+ x_shape = [2 , 3 , 32 , 16 ]
831
+ x_val = make_xval (x_shape )
832
+ def func ():
833
+ x = tf_placeholder (tf .float32 , [2 , 3 , None , None ], name = _TFINPUT )
834
+ t = tf .keras .layers .Conv2D (
835
+ filters = 768 ,
836
+ kernel_size = 3 ,
837
+ dilation_rate = 3 ,
838
+ padding = "VALID" ,
839
+ data_format = 'channels_first'
840
+ )
841
+ t .build (x_shape )
842
+ y = t .call (x )
843
+ return tf .identity (y , name = _TFOUTPUT )
844
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 ,
845
+ as_session = True , premade_placeholders = True ,
846
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
847
+
848
+ @check_tf_min_version ("1.15" )
780
849
def test_conv3d_dilations_rewriter (self ):
781
850
x_shape = [2 , 32 , 16 , 8 , 3 ]
782
851
x_val = make_xval (x_shape )
@@ -789,6 +858,26 @@ def func(x):
789
858
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
790
859
graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
791
860
861
+ @check_tf_min_version ("1.15" )
862
+ @skip_tf_cpu ("only tf_gpu can run conv3d with NCDHW format" )
863
+ def test_ncdhw_conv3d_dilations_rewriter (self ):
864
+ x_shape = [2 , 3 , 32 , 16 , 8 ]
865
+ x_val = make_xval (x_shape )
866
+ for p in ['SAME' , 'VALID' ]:
867
+ def func (x ):
868
+ t = tf .keras .layers .Conv3D (
869
+ filters = 768 ,
870
+ kernel_size = 3 ,
871
+ dilation_rate = 3 ,
872
+ padding = p ,
873
+ data_format = 'channels_first'
874
+ )
875
+ t .build (x_shape )
876
+ y = t .call (x )
877
+ return tf .identity (y , name = _TFOUTPUT )
878
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 , atol = 1e-2 , as_session = True ,
879
+ graph_validator = lambda g : check_op_count (g , "Reshape" , 0 , disabled = False ))
880
+
792
881
@skip_tf2 ("Uses tf.layers" )
793
882
def test_conv1d_tf1_dilations_rewriter (self ):
794
883
x_shape = [2 , 32 , 3 ]
0 commit comments