Merged
16 changes: 16 additions & 0 deletions tests/test_backend.py
@@ -1790,6 +1790,14 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)

def test_reducemax_global_max_pool(self):
for keepdims in [True, False]:
x_val = make_xval((2, 3, 4, 5, 6))
def func(x):
x_ = tf.reduce_max(x, axis=[2, 3, 4], keepdims=keepdims)
return tf.add(x_, 0, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@skip_caffe2_backend()
def test_reduceprod(self):
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
@@ -1805,6 +1813,14 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

def test_reducemean_global_avg_pool(self):
for keepdims in [True, False]:
x_val = make_xval((2, 3, 4, 5))
def func(x):
x_ = tf.reduce_mean(x, axis=[2, 3], keepdims=keepdims)
return tf.add(x_, 0, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@skip_caffe2_backend()
@check_onnxruntime_incompatibility("Pow")
def test_pow_scalar(self):
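For reference, a short sketch (not part of the diff; the shapes mirror the test inputs, everything else is illustrative) of what the two new tests exercise: reducing over every axis after batch and channel is a global spatial pool, and keepdims only decides whether the pooled axes survive as size-1 dimensions.

import numpy as np
import tensorflow as tf

# Illustrative only: mirrors the reduction in test_reducemean_global_avg_pool.
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
kept = tf.reduce_mean(x, axis=[2, 3], keepdims=True)      # shape (2, 3, 1, 1)
dropped = tf.reduce_mean(x, axis=[2, 3], keepdims=False)  # shape (2, 3)
print(kept.shape, dropped.shape)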
2 changes: 2 additions & 0 deletions tf2onnx/optimizer/__init__.py
@@ -18,6 +18,7 @@
from .upsample_optimizer import UpsampleOptimizer
from .const_dequantize_optimizer import ConstDequantizeOptimizer
from .reshape_optimizer import ReshapeOptimizer
from .global_pool_optimizer import GlobalPoolOptimizer
from .. import logging

# optimizer sequence need to be considered carefully
@@ -33,6 +34,7 @@
("reshape_optimizer", ReshapeOptimizer),
("remove_identity", IdentityOptimizer),
("remove_back_to_back", BackToBackOptimizer),
("global_pool_optimizer", GlobalPoolOptimizer),
])


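As the sequence comment above suggests, the new pass is appended at the end of the registry so it runs after the cleanup passes. A minimal sketch of how such an ordered registry might be driven, assuming each optimizer exposes the GraphOptimizerBase optimize(graph) entry point (the run_all helper and the trimmed sequence are illustrative, not tf2onnx's actual driver):

from collections import OrderedDict
from tf2onnx.optimizer import IdentityOptimizer, BackToBackOptimizer, GlobalPoolOptimizer

# Illustrative sketch: passes are applied in registration order, so the global
# pool rewrite only sees the graph after identity and back-to-back cleanup.
_example_sequence = OrderedDict([
    ("remove_identity", IdentityOptimizer),
    ("remove_back_to_back", BackToBackOptimizer),
    ("global_pool_optimizer", GlobalPoolOptimizer),
])

def run_all(graph):
    for _name, opt_cls in _example_sequence.items():
        graph = opt_cls().optimize(graph)  # assumed base-class interface
    return graph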
60 changes: 60 additions & 0 deletions tf2onnx/optimizer/global_pool_optimizer.py
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0


"""global pool optimizer
Replaces ReduceMean and ReduceMax patterns with GlobalAveragePool and GlobalMaxPool
"""

from onnx import TensorProto
from tf2onnx.graph_builder import GraphBuilder
from .optimizer_base import GraphOptimizerBase

# pylint: disable=logging-not-lazy,unused-argument,missing-docstring


class GlobalPoolOptimizer(GraphOptimizerBase):

def __init__(self): # pylint: disable=useless-super-delegation
super(GlobalPoolOptimizer, self).__init__()

def _optimize(self, graph):
return self._apply_optimization(graph, self._optimize_at_current_graph_level)

def _optimize_at_current_graph_level(self, graph):
graph_changed = True
while graph_changed:
graph_changed = False
ops = graph.get_nodes()
for op in ops:
if op.type in ["ReduceMean", "ReduceMax"] and self._optimize_reduce(op, graph):
graph_changed = True
self.graph_been_opt = True
return graph

def _optimize_reduce(self, node, graph):
if graph.get_dtype(node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]:
return False
if node.output[0] in graph.outputs:
# Replacement is unsafe
return False
axes = node.get_attr_value('axes')
inp_rank = graph.get_rank(node.input[0])
if inp_rank is None:
return False
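        # Only a reduction over every axis after batch and channel (N, C) matches
        # the semantics of a global spatial pool.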
if axes != list(range(2, inp_rank)):
return False
op_map = {"ReduceMean": "GlobalAveragePool", "ReduceMax": "GlobalMaxPool"}
node.type = op_map[node.type]
del node.attr['axes']
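        # Global pooling always keeps the reduced axes as size-1 dims, so when the
        # original reduce op had keepdims=0, append a Squeeze to restore its shape.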
if not node.get_attr_value('keepdims', True):
out_shapes = node.output_shapes
out_dtypes = node.output_dtypes
new_out_shape = graph.get_shape(node.input[0])[:2] + [1] * len(axes)
graph.set_shape(node.output[0], new_out_shape)
squeeze_node = GraphBuilder(graph).make_squeeze(
{'data': node.output[0], 'axes': axes}, shapes=out_shapes, dtypes=out_dtypes,
return_node=True, op_name_scope=node.name)
graph.insert_node_on_output(squeeze_node, node.output[0])
if 'keepdims' in node.attr:
del node.attr['keepdims']
return True
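A minimal sketch of the equivalence this rewrite relies on, built with onnx.helper and run through onnxruntime (tensor names, shapes, and opset are illustrative): GlobalAveragePool matches ReduceMean over all spatial axes with keepdims set, which is why the keepdims=0 branch above has to insert a Squeeze.

import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort

# Build a one-node model: GlobalAveragePool over an NCHW input.
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4, 5])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 1, 1])
node = helper.make_node("GlobalAveragePool", ["x"], ["y"])
graph = helper.make_graph([node], "gap", [inp], [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
y = ort.InferenceSession(model.SerializeToString()).run(None, {"x": x})[0]

# GlobalAveragePool == ReduceMean(axes=[2, 3], keepdims=1); without keepdims
# the optimizer has to append a Squeeze over the same axes to match shapes.
assert np.allclose(y, x.mean(axis=(2, 3), keepdims=True), atol=1e-6)
assert np.allclose(np.squeeze(y, axis=(2, 3)), x.mean(axis=(2, 3)), atol=1e-6)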