diff --git a/tests/test_backend.py b/tests/test_backend.py
index 7403b272d..bdeef5fd6 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -1790,6 +1790,14 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
 
+    def test_reducemax_global_max_pool(self):
+        for keepdims in [True, False]:
+            x_val = make_xval((2, 3, 4, 5, 6))
+            def func(x):
+                x_ = tf.reduce_max(x, axis=[2, 3, 4], keepdims=keepdims)
+                return tf.add(x_, 0, name=_TFOUTPUT)
+            self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @skip_caffe2_backend()
     def test_reduceprod(self):
         x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
@@ -1805,6 +1813,14 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    def test_reducemean_global_avg_pool(self):
+        for keepdims in [True, False]:
+            x_val = make_xval((2, 3, 4, 5))
+            def func(x):
+                x_ = tf.reduce_mean(x, axis=[2, 3], keepdims=keepdims)
+                return tf.add(x_, 0, name=_TFOUTPUT)
+            self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @skip_caffe2_backend()
     @check_onnxruntime_incompatibility("Pow")
     def test_pow_scalar(self):
diff --git a/tf2onnx/optimizer/__init__.py b/tf2onnx/optimizer/__init__.py
index 79569405f..188702032 100644
--- a/tf2onnx/optimizer/__init__.py
+++ b/tf2onnx/optimizer/__init__.py
@@ -18,6 +18,7 @@
 from .upsample_optimizer import UpsampleOptimizer
 from .const_dequantize_optimizer import ConstDequantizeOptimizer
 from .reshape_optimizer import ReshapeOptimizer
+from .global_pool_optimizer import GlobalPoolOptimizer
 from .. import logging
 
 # optimizer sequence need to be considered carefully
@@ -33,6 +34,7 @@
     ("reshape_optimizer", ReshapeOptimizer),
     ("remove_identity", IdentityOptimizer),
     ("remove_back_to_back", BackToBackOptimizer),
+    ("global_pool_optimizer", GlobalPoolOptimizer),
 ])
 
diff --git a/tf2onnx/optimizer/global_pool_optimizer.py b/tf2onnx/optimizer/global_pool_optimizer.py
new file mode 100644
index 000000000..b0b0640d6
--- /dev/null
+++ b/tf2onnx/optimizer/global_pool_optimizer.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""global pool optimizer
+   Replaces ReduceMean and ReduceMax patterns with GlobalAveragePool and GlobalMaxPool
+"""
+
+from onnx import TensorProto
+from tf2onnx.graph_builder import GraphBuilder
+from .optimizer_base import GraphOptimizerBase
+
+# pylint: disable=logging-not-lazy,unused-argument,missing-docstring
+
+
+class GlobalPoolOptimizer(GraphOptimizerBase):
+
+    def __init__(self):  # pylint: disable=useless-super-delegation
+        super(GlobalPoolOptimizer, self).__init__()
+
+    def _optimize(self, graph):
+        return self._apply_optimization(graph, self._optimize_at_current_graph_level)
+
+    def _optimize_at_current_graph_level(self, graph):
+        graph_changed = True
+        while graph_changed:
+            graph_changed = False
+            ops = graph.get_nodes()
+            for op in ops:
+                if op.type in ["ReduceMean", "ReduceMax"] and self._optimize_reduce(op, graph):
+                    graph_changed = True
+                    self.graph_been_opt = True
+        return graph
+
+    def _optimize_reduce(self, node, graph):
+        if graph.get_dtype(node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]:
+            return False
+        if node.output[0] in graph.outputs:
+            # Replacement is unsafe
+            return False
+        axes = node.get_attr_value('axes')
+        inp_rank = graph.get_rank(node.input[0])
+        if inp_rank is None:
+            return False
+        if axes != list(range(2, inp_rank)):
+            return False
+        op_map = {"ReduceMean": "GlobalAveragePool", "ReduceMax": "GlobalMaxPool"}
+        node.type = op_map[node.type]
+        del node.attr['axes']
+        if not node.get_attr_value('keepdims', True):
+            out_shapes = node.output_shapes
+            out_dtypes = node.output_dtypes
+            new_out_shape = graph.get_shape(node.input[0])[:2] + [1] * len(axes)
+            graph.set_shape(node.output[0], new_out_shape)
+            squeeze_node = GraphBuilder(graph).make_squeeze(
+                {'data': node.output[0], 'axes': axes}, shapes=out_shapes, dtypes=out_dtypes,
+                return_node=True, op_name_scope=node.name)
+            graph.insert_node_on_output(squeeze_node, node.output[0])
+        if 'keepdims' in node.attr:
+            del node.attr['keepdims']
+        return True
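
For reference (not part of the patch), a minimal numpy sketch of the equivalence the optimizer relies on; the shapes and variable names here are illustrative, not taken from the codebase:

import numpy as np

# NCHW input, the layout ONNX pooling ops assume
x = np.random.rand(2, 3, 4, 5).astype(np.float32)

# ReduceMean over every spatial axis (2..rank-1) with keepdims=True
# computes exactly what GlobalAveragePool does: one value per (N, C) pair.
reduced = x.mean(axis=(2, 3), keepdims=True)
pooled = x.reshape(2, 3, -1).mean(axis=2).reshape(2, 3, 1, 1)
assert np.allclose(reduced, pooled)

# With keepdims=False the pooled result only differs by its size-1 spatial
# dims, which is why the optimizer appends a Squeeze node in that case.
assert np.allclose(x.mean(axis=(2, 3)), pooled.squeeze(axis=(2, 3)))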