diff --git a/tests/keras2onnx_applications/nightly_build/test_nlp.py b/tests/keras2onnx_applications/nightly_build/test_nlp.py
index c3274da6d..b64b0ad5d 100644
--- a/tests/keras2onnx_applications/nightly_build/test_nlp.py
+++ b/tests/keras2onnx_applications/nightly_build/test_nlp.py
@@ -5,7 +5,7 @@
 import unittest
 import mock_keras2onnx
 import numpy as np
-from mock_keras2onnx.proto import keras, is_tf_keras
+from mock_keras2onnx.proto import keras, is_tensorflow_older_than
 from os.path import dirname, abspath
 sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
 from test_utils import run_onnx_runtime
@@ -91,6 +91,7 @@ def test_babi_rnn(self):
         expected = model.predict([x, y])
         self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, {model.input_names[0]: x, model.input_names[1]: y}, expected, self.model_files))
 
+    @unittest.skipIf(is_tensorflow_older_than('2.0.0'), "Result is slightly different in tf1")
     @unittest.skipIf(get_maximum_opset_supported() < 9,
                      "None seq_length LSTM is not supported before opset 9.")
     def test_imdb_bidirectional_lstm(self):
diff --git a/tests/test_backend.py b/tests/test_backend.py
index a5f3cf3cf..e1042648c 100755
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -740,6 +740,29 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
                             graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
 
+    @check_tf_min_version("1.15")
+    @skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
+    def test_conv2d_biasadd_rewriter(self):
+        x_shape = [2, 3, 32, 16]
+        x_val = make_xval(x_shape)
+        def func(x):
+            middles = tf.keras.layers.ZeroPadding2D(
+                padding=(0, 4),
+                data_format="channels_first",
+                name="padding"
+            )(x)
+            t = tf.keras.layers.Conv2D(
+                filters=768,
+                kernel_size=3,
+                strides=1,
+                use_bias=True,
+                data_format="channels_first",
+                name="conv2d"
+            )(middles)
+            return tf.identity(t, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
+                            graph_validator=lambda g: check_op_count(g, "Add", 0, disabled=False))
+
     @check_tf_min_version("1.15")
     def test_conv2d_dilations_rewriter(self):
         x_shape = [2, 32, 16, 3]
@@ -2353,6 +2376,9 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @skip_tflite("tflite does not support uint32 if tf version <= 2.3.0")
+    @check_opset_min_version(6, "cast")
+    def test_cast_uint32(self):
         x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2))
         def func(x):
             x_ = tf.cast(x, tf.uint64)
diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index 104ec9855..d2f65f812 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -373,6 +373,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu
         with tf.device("/cpu:0"):
             frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
 
+    with tf.Graph().as_default():
+        tf.import_graph_def(frozen_graph, name="")
+        frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph, False)
     model_proto, external_tensor_storage = _convert_common(
         frozen_graph,
         name=model.name,
diff --git a/tf2onnx/rewriter/__init__.py b/tf2onnx/rewriter/__init__.py
index 71664e60a..fc551a4ad 100644
--- a/tf2onnx/rewriter/__init__.py
+++ b/tf2onnx/rewriter/__init__.py
@@ -28,7 +28,6 @@
 __all__ = [
     "rewrite_cond",
-    "rewrite_conv2d_with_pad",
     "rewrite_dropout",
     "rewrite_eye",
     "rewrite_flatten",
@@ -49,6 +48,7 @@
     "rewrite_quantize_and_dequantize",
     "rewrite_layer_normalization",
     "rewrite_conv_dilations",
+    "rewrite_conv2d_with_pad",
     "rewrite_ragged_variant_shape",
     "rewriter_lstm_tf2",
     "rewrite_gru_tf2",
diff --git a/tf2onnx/rewriter/conv2d_with_add_rewriter.py b/tf2onnx/rewriter/conv2d_with_add_rewriter.py
index aa941d75b..0f8d87d02 100644
--- a/tf2onnx/rewriter/conv2d_with_add_rewriter.py
+++ b/tf2onnx/rewriter/conv2d_with_add_rewriter.py
@@ -13,31 +13,40 @@
 # pylint: disable=missing-docstring
 def rewrite_biasadd_with_conv2d(g, ops):
-    pattern = \
+    pattern1 = \
         OpTypePattern('BiasAdd', name='biasadd', inputs=[
             OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
-    matcher = GraphMatcher(pattern)
-    match_results = list(matcher.match_ops(ops))
-    for match in match_results:
-        biasadd = match.get_op('biasadd')
-        conv = match.get_op('conv')
-
-        #backup the conv and biasadd values
-        conv_type = conv.type
-        conv_input = conv.input
-        conv_attr = conv.attr
-        dtype = g.get_dtype(conv.output[0])
-        shape = g.get_shape(conv.output[0])
-        conv_name = biasadd.name
-        conv_output = biasadd.output
-        conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
-
-        if len(g.find_output_consumers(conv.output[0])) > 1:
-            continue
-        # Remove the Conv and BiasAdd node
-        g.remove_node(conv.name)
-        g.remove_node(biasadd.name)
-
-        g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
-                    shapes=[shape], dtypes=[dtype], skip_conversion=False)
+    pattern2 = \
+        OpTypePattern('BiasAdd', name='biasadd', inputs=[
+            OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
+                '*', '*', '*']), '*'], allow_reorder=True)
+
+    for pattern in [pattern1, pattern2]:
+        matcher = GraphMatcher(pattern)
+        match_results = list(matcher.match_ops(ops))
+        for match in match_results:
+            biasadd = match.get_op('biasadd')
+            conv = match.get_op('conv')
+
+            # Backup the conv and biasadd values
+            conv_type = conv.type
+            conv_input = conv.input
+            conv_attr = conv.attr
+            dtype = g.get_dtype(conv.output[0])
+            shape = g.get_shape(conv.output[0])
+            conv_name = biasadd.name
+            conv_output = biasadd.output
+            if pattern == pattern2:
+                conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
+            else:
+                conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]
+
+            if len(g.find_output_consumers(conv.output[0])) > 1:
+                continue
+            # Remove the Conv and BiasAdd node
+            g.remove_node(conv.name)
+            g.remove_node(biasadd.name)
+
+            g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
+                        shapes=[shape], dtypes=[dtype], skip_conversion=False)
     return ops
diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index f6699087d..1a827c87d 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -681,7 +681,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
     rewrite_options = config.graph_options.rewrite_options
     config.graph_options.infer_shapes = True
     # TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
-    # depends on so for now don't turn this on.
+    # depends on so for now don't turn this on; fold_constant is always enabled now.
     rewrite_options.optimizers[:] = [
         # 'pruning',
         'constfold', 'arithmetic', 'dependency', 'function',
         'constfold', 'function'
diff --git a/tf2onnx/version.py b/tf2onnx/version.py
index 40889b5d1..ff7beba94 100644
--- a/tf2onnx/version.py
+++ b/tf2onnx/version.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
 
-
-version = '1.8.0'
-git_version = '24080398ff4793ed8aac028ffa4b714a4803d7fb'
+version = '1.10.0'
+git_version = '219e00c073f6e73fba7335630dcf1f96cc82c983'