Skip to content

Commit ba4ce36

Browse files
Merge pull request #1221 from xadupre/squ
[WIP] Support operators Squeeze/Unsqueeze for opset 13 (v2)
2 parents 1400d89 + 751fc5c commit ba4ce36

20 files changed

+545
-135
lines changed

tests/test_optimizers.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def test_transpose_with_identity(self):
318318
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
319319
model_proto, remaining_transpose_num=1)
320320

321+
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
321322
def test_transpose_with_squeeze1(self):
322323
# squeeze the first dim
323324
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -335,6 +336,26 @@ def test_transpose_with_squeeze1(self):
335336
model_proto, remaining_transpose_num=1)
336337
self.check_transpose_perm(model_after_opt, [1, 2, 0])
337338

339+
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
340+
def test_transpose_with_squeeze1_13(self):
341+
# squeeze the first dim
342+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
343+
axes = self._make_onnx_const(np.array([0], dtype=np.int64), "axes")
344+
node2 = helper.make_node("Squeeze", ["Y", "axes"], ["Z"], name="squeeze")
345+
346+
graph = helper.make_graph(
347+
[node1, node2, axes],
348+
"transpose_with_squeeze",
349+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 4, 5))],
350+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (4, 5, 3))],
351+
)
352+
353+
model_proto = self.make_model(graph, producer_name="onnx-tests")
354+
model_after_opt = self.run_transpose_compare(["Z"], {"X": np.random.randn(1, 3, 4, 5).astype(np.float32)},
355+
model_proto, remaining_transpose_num=1)
356+
self.check_transpose_perm(model_after_opt, [1, 2, 0])
357+
358+
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
338359
def test_transpose_with_squeeze2(self):
339360
# squeeze the second dim
340361
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -352,6 +373,26 @@ def test_transpose_with_squeeze2(self):
352373
model_proto, remaining_transpose_num=1)
353374
self.check_transpose_perm(model_after_opt, [0, 2, 1])
354375

376+
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
377+
def test_transpose_with_squeeze2_13(self):
378+
# squeeze the second dim
379+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
380+
axes = self._make_onnx_const(np.array([1], dtype=np.int64), "axes")
381+
node2 = helper.make_node("Squeeze", ["Y", "axes"], ["Z"], name="squeeze")
382+
383+
graph = helper.make_graph(
384+
[node1, node2, axes],
385+
"transpose_with_squeeze",
386+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (3, 4, 1, 5))],
387+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 5, 4))],
388+
)
389+
390+
model_proto = self.make_model(graph, producer_name="onnx-tests")
391+
model_after_opt = self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 4, 1, 5).astype(np.float32)},
392+
model_proto, remaining_transpose_num=1)
393+
self.check_transpose_perm(model_after_opt, [0, 2, 1])
394+
395+
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
355396
def test_transpose_with_squeeze3(self):
356397
# squeeze the last dim
357398
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -368,6 +409,25 @@ def test_transpose_with_squeeze3(self):
368409
self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 4, 5).astype(np.float32)},
369410
model_proto, remaining_transpose_num=0)
370411

412+
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
413+
def test_transpose_with_squeeze3_13(self):
414+
# squeeze the last dim
415+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
416+
axes = self._make_onnx_const(np.array([3], dtype=np.int64), "axes")
417+
node2 = helper.make_node("Squeeze", ["Y", "axes"], ["Z"], name="squeeze")
418+
419+
graph = helper.make_graph(
420+
[node1, node2, axes],
421+
"transpose_with_squeeze",
422+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (3, 1, 4, 5))],
423+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 4, 5))],
424+
)
425+
426+
model_proto = self.make_model(graph, producer_name="onnx-tests")
427+
self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 4, 5).astype(np.float32)},
428+
model_proto, remaining_transpose_num=0)
429+
430+
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
371431
def test_transpose_with_squeeze4(self):
372432
# squeeze the two dims
373433
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -384,6 +444,24 @@ def test_transpose_with_squeeze4(self):
384444
self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 1, 5).astype(np.float32)},
385445
model_proto, remaining_transpose_num=0)
386446

447+
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
448+
def test_transpose_with_squeeze4_13(self):
449+
# squeeze the two dims
450+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
451+
axes = self._make_onnx_const(np.array([1, 3], dtype=np.int64), "axes")
452+
node2 = helper.make_node("Squeeze", ["Y", "axes"], ["Z"], name="squeeze")
453+
454+
graph = helper.make_graph(
455+
[node1, node2, axes],
456+
"transpose_with_squeeze",
457+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (3, 1, 1, 5))],
458+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (3, 5))],
459+
)
460+
461+
model_proto = self.make_model(graph, producer_name="onnx-tests")
462+
self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 1, 5).astype(np.float32)},
463+
model_proto, remaining_transpose_num=0)
464+
387465
def test_transpose_with_loop(self):
388466
def _define_loop_graph(external_inputs):
389467
# external_inputs: external node which will be used by this graph
@@ -1090,6 +1168,7 @@ def test_const_fold_node_is_output(self):
10901168
self.run_transpose_compare(["res"], {},
10911169
model_proto, remaining_transpose_num=0)
10921170

1171+
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
10931172
def test_const_fold_unsqueeze_with_const(self):
10941173
shape = (6, 6)
10951174
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
@@ -1109,6 +1188,27 @@ def test_const_fold_unsqueeze_with_const(self):
11091188
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
11101189
"Unsqueeze", 0)
11111190

1191+
@check_opset_min_version(13, "Squeeze/Unsqueeze changed in opset 13")
1192+
def test_const_fold_unsqueeze_with_const_13(self):
1193+
shape = (6, 6)
1194+
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
1195+
vals=np.random.randn(*shape).flatten().astype(np.float32))
1196+
node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
1197+
axes = self._make_onnx_const(np.array([0, 2, 3], dtype=np.int64), "axes")
1198+
node2 = helper.make_node("Unsqueeze", ["const", "axes"], ["value1"])
1199+
node3 = helper.make_node("Add", ["value1", "X"], ["res"])
1200+
1201+
graph = helper.make_graph(
1202+
[node1, node2, node3, axes],
1203+
"test_const_fold_unsqueeze_with_const",
1204+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1,))],
1205+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 6, 1, 1, 6))],
1206+
)
1207+
1208+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1209+
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
1210+
"Unsqueeze", 0)
1211+
11121212
def test_const_fold_cast_with_const(self):
11131213
shape = (6, 6)
11141214
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,

tf2onnx/custom_opsets/string_ops.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tf2onnx import constants, handler
99
from tf2onnx.handler import tf_op
1010
from tf2onnx import utils
11+
from tf2onnx.graph_builder import GraphBuilder
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -16,7 +17,7 @@
1617
@tf_op(["StringSplit", "StringSplitV2"], domain=constants.CONTRIB_OPS_DOMAIN)
1718
class StringOps:
1819
@classmethod
19-
def version_1(cls, ctx, node, **kwargs):
20+
def any_version(cls, opset, ctx, node, **kwargs):
2021
if node.type == "StringSplit":
2122
skip_empty = node.get_attr_value('skip_empty', True)
2223
else:
@@ -25,10 +26,21 @@ def version_1(cls, ctx, node, **kwargs):
2526
node.domain = constants.CONTRIB_OPS_DOMAIN
2627
for a in list(node.attr.keys()):
2728
del node.attr[a]
28-
unsqueeze_node = ctx.make_node("Unsqueeze", [node.input[1]], attr={'axes': [0]})
29+
unsqueeze_node = GraphBuilder(ctx).make_squeeze(
30+
{'data': node.input[1], 'axes': [0]}, return_node=True)
31+
2932
skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool))
3033
ctx.replace_inputs(node, [node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]])
3134

35+
@classmethod
36+
def version_1(cls, ctx, node, **kwargs):
37+
cls.any_version(1, ctx, node, **kwargs)
38+
39+
@classmethod
40+
def version_13(cls, ctx, node, **kwargs):
41+
cls.any_version(13, ctx, node, **kwargs)
42+
43+
3244
@tf_op("StringToHashBucketFast", domain=constants.CONTRIB_OPS_DOMAIN)
3345
class StringToHashBucketFast:
3446
@classmethod
@@ -59,7 +71,7 @@ def version_1(cls, ctx, node, **kwargs):
5971
@tf_op("StringJoin", domain=constants.CONTRIB_OPS_DOMAIN)
6072
class StringJoin:
6173
@classmethod
62-
def version_1(cls, ctx, node, **kwargs):
74+
def any_version(cls, opset, ctx, node, **kwargs):
6375
node.domain = constants.CONTRIB_OPS_DOMAIN
6476
separator = node.get_attr_value("separator")
6577
if separator is None:
@@ -76,11 +88,20 @@ def version_1(cls, ctx, node, **kwargs):
7688
if ctx.get_shape(inp) == [] and shape_node is not None:
7789
expand_node = ctx.make_node("Expand", [inp, shape_node.output[0]])
7890
inp = expand_node.output[0]
79-
unsqueeze_node = ctx.make_node("Unsqueeze", [inp], attr={'axes': [0]})
91+
unsqueeze_node = GraphBuilder(ctx).make_squeeze({'data': inp, 'axes': [0]})
8092
unsqueezes.append(unsqueeze_node.output[0])
8193
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
8294
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
8395

96+
@classmethod
97+
def version_1(cls, ctx, node, **kwargs):
98+
cls.any_version(1, ctx, node, **kwargs)
99+
100+
@classmethod
101+
def version_13(cls, ctx, node, **kwargs):
102+
cls.any_version(13, ctx, node, **kwargs)
103+
104+
84105
@tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
85106
class StringEqual:
86107
@classmethod

tf2onnx/graph.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,8 +1258,25 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
12581258
break
12591259
return new_node
12601260

1261-
def insert_new_node_on_output(self, op_type, output_name, name=None, inputs=None, domain=None, **kwargs):
1261+
def insert_node_on_output(self, node, output_name=None):
1262+
"""
1263+
The inserted node takes the *output_name* as input and produces a
1264+
new output. The function goes through every node taking *output_name*
1265+
and replaces it by the new output name.
1266+
"""
1267+
if output_name is None:
1268+
output_name = node.input[0]
1269+
new_output = node.output[0]
1270+
1271+
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]]
1272+
to_replace = [n for n in to_replace if n != node]
1273+
self.replace_all_inputs(output_name, new_output, ops=to_replace)
1274+
return node
1275+
1276+
def insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs):
12621277
"""Create and insert a new node into the graph.
1278+
It then calls insert_node_on_output.
1279+
12631280
Args:
12641281
op_type: type for new operation
12651282
output_name: the names of the outputs above us
@@ -1273,6 +1290,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, inputs=None
12731290
type(output_name))
12741291
utils.make_sure(isinstance(op_type, six.text_type), "op_type's type is not expected: %s",
12751292
type(op_type))
1293+
utils.make_sure(output_name is not None, "output_name cannot be None for op_type=%r.", op_type)
12761294

12771295
if inputs is None:
12781296
inputs = [output_name]
@@ -1281,11 +1299,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, inputs=None
12811299

12821300
new_output = port_name(name)
12831301
new_node = self.make_node(op_type, inputs, attr=kwargs, outputs=[new_output], name=name, domain=domain)
1284-
1285-
to_replace = [self.get_node_by_name(n) for n in self._output_to_consumers[output_name]]
1286-
to_replace = [n for n in to_replace if n != new_node]
1287-
self.replace_all_inputs(output_name, new_output, ops=to_replace)
1288-
return new_node
1302+
return self.insert_node_on_output(new_node, output_name)
12891303

12901304
def find_output_consumers(self, output_name):
12911305
"""Find all nodes consuming a given output."""

tf2onnx/graph_builder.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, graph):
2424
def graph(self):
2525
return self._g
2626

27-
def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
27+
def make_slice(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False):
2828
"""
2929
slice changes its schema at opset 10: it treats some attributes as dynamic input
3030
so this function has to process inputs according to graph's opset version
@@ -77,8 +77,11 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
7777
if input_data != utils.ONNX_EMPTY_INPUT:
7878
utils.make_sure(dtype == self.graph.get_dtype(input_data), "dtype should be same")
7979

80-
return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
81-
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
80+
node = self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
81+
outputs=outputs, shapes=shapes, dtypes=dtypes)
82+
if return_node:
83+
return node
84+
return node.output[0]
8285

8386
def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
8487
"""
@@ -116,7 +119,7 @@ def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
116119
return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
117120
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
118121

119-
def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None):
122+
def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False, op_name_scope=None):
120123
"""
121124
Squeeze changes its schema at opset 13: it treats axes as a dynamic input
122125
kwargs: key could be ["data", "axes"].
@@ -150,10 +153,14 @@ def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None):
150153
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
151154
inputs = inputs[:-1]
152155

153-
return self.graph.make_node(op_type="Squeeze", inputs=inputs, attr=attr, name=name,
154-
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
156+
node = self.graph.make_node(op_type="Squeeze", inputs=inputs, attr=attr, name=name,
157+
outputs=outputs, shapes=shapes, dtypes=dtypes,
158+
op_name_scope=op_name_scope)
159+
if return_node:
160+
return node
161+
return node.output[0]
155162

156-
def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None):
163+
def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False, op_name_scope=None):
157164
"""
158165
Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
159166
kwargs: key could be ["data", "axes"].
@@ -187,8 +194,12 @@ def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None):
187194
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
188195
inputs = inputs[:-1]
189196

190-
return self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
191-
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
197+
node = self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
198+
outputs=outputs, shapes=shapes, dtypes=dtypes,
199+
op_name_scope=op_name_scope)
200+
if return_node:
201+
return node
202+
return node.output[0]
192203

193204
def convert_to_input(self, tensor, const_name, is_optional=False, dtype=None):
194205
"""in ONNX, input shold come from node, so it must be a string"""

tf2onnx/onnx_opset/controlflow.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tf2onnx import utils
2020
from tf2onnx.handler import tf_op
2121
from tf2onnx.tf_loader import find_function
22+
from tf2onnx.graph_builder import GraphBuilder
2223

2324

2425
logger = logging.getLogger(__name__)
@@ -289,6 +290,20 @@ def version_7(cls, ctx, node, **kwargs):
289290
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], name=node.child_name(), axes=[0])
290291
ctx.insert_new_node_on_output("Squeeze", node.output[0], name=node.child_name(), axes=[0])
291292

293+
@classmethod
294+
def version_13(cls, ctx, node, **kwargs):
295+
ctx.ta_reads.append(node.input[0])
296+
node.type = "Gather"
297+
ctx.replace_inputs(node, [node.input[0], node.input[1]])
298+
299+
g = GraphBuilder(ctx)
300+
301+
usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]}, return_node=True)
302+
ctx.insert_node_on_output(usq_node)
303+
304+
sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]}, return_node=True)
305+
ctx.insert_node_on_output(sq_node)
306+
292307

293308
@tf_op(["TensorListLength"])
294309
class TensorListLength:
@@ -607,7 +622,7 @@ def inline_subgraph(parent, g, scope, binding):
607622
parent.set_dtype(name, g.get_dtype(name))
608623
parent.set_shape(name, g.get_shape(name))
609624

610-
return g.outputs
625+
return g.outputs
611626

612627

613628
def parameter_binding(g, inputs, state_vars=None):

0 commit comments

Comments
 (0)