Skip to content

Commit e1a1114

Browse files
Merge pull request #1216 from onnx/tom/SqueezeUnsqueeze
Tom/squeeze unsqueeze
2 parents 5021d4c + bc62269 commit e1a1114

File tree

3 files changed

+121
-19
lines changed

3 files changed

+121
-19
lines changed

tests/test_backend.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,12 +1044,31 @@ def func(x1, x2):
10441044
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
10451045

10461046
def test_sequeeze_no_axis_specified(self):
1047-
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
1047+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2, 1, 1))
1048+
def func(x):
1049+
x_ = tf.squeeze(x)
1050+
return tf.identity(x_, name=_TFOUTPUT)
1051+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1052+
1053+
def test_sequeeze_no_axis(self):
1054+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))
10481055
def func(x):
10491056
x_ = tf.squeeze(x)
10501057
return tf.identity(x_, name=_TFOUTPUT)
10511058
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
10521059

1060+
@check_opset_min_version(11, "Pad")
1061+
def test_sequeeze_no_axis_specified_unknown_rank(self):
1062+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
1063+
y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
1064+
z_val = np.zeros((1, 2), dtype=np.int64)
1065+
def func(x, y, z):
1066+
y_ = tf.pad(y, z)
1067+
x_ = tf.reshape(x, y_)
1068+
x_ = tf.squeeze(x_)
1069+
return tf.identity(x_, name=_TFOUTPUT)
1070+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
1071+
10531072
def test_sequeeze_positive_axis(self):
10541073
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
10551074
def func(x):
@@ -1071,6 +1090,18 @@ def func(x):
10711090
return tf.identity(x_, name=_TFOUTPUT)
10721091
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
10731092

1093+
@check_opset_min_version(11, "Squeeze")
1094+
def test_sequeeze_mixed_axis_unknown_rank(self):
1095+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
1096+
y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
1097+
z_val = np.zeros((1, 2), dtype=np.int64)
1098+
def func(x, y, z):
1099+
y_ = tf.pad(y, z)
1100+
x_ = tf.reshape(x, y_)
1101+
x_ = tf.squeeze(x_, [1, -1])
1102+
return tf.identity(x_, name=_TFOUTPUT)
1103+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
1104+
10741105
def test_transpose(self):
10751106
x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
10761107
def func(x):

tf2onnx/graph_builder.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,80 @@ def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
116116
return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
117117
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
118118

119+
def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None):
120+
"""
121+
Squeeze changes its schema at opset 13: it treats axes as a dynamic input
122+
kwargs: key could be ["data", "axes"].
123+
"""
124+
outputs = kwargs.pop("outputs", None)
125+
126+
if self.graph.opset < 13:
127+
data = kwargs.pop("data")
128+
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
129+
attr = {"axes": axes}
130+
inputs = [data]
131+
else:
132+
data = kwargs.pop("data")
133+
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
134+
attr = {}
135+
inputs = [data, axes]
136+
137+
if kwargs:
138+
logger.warning("kwargs contains un-used key")
139+
140+
new_attr = {}
141+
for key, val in attr.items():
142+
if val is not None:
143+
new_attr[key] = val
144+
attr = new_attr
145+
146+
for ind, val in enumerate(inputs):
147+
if val is None:
148+
inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX
149+
# remove tailing ""
150+
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
151+
inputs = inputs[:-1]
152+
153+
return self.graph.make_node(op_type="Squeeze", inputs=inputs, attr=attr, name=name,
154+
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
155+
156+
def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None):
157+
"""
158+
Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
159+
kwargs: key could be ["data", "axes"].
160+
"""
161+
outputs = kwargs.pop("outputs", None)
162+
163+
if self.graph.opset < 13:
164+
data = kwargs.pop("data")
165+
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
166+
attr = {"axes": axes}
167+
inputs = [data]
168+
else:
169+
data = kwargs.pop("data")
170+
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
171+
attr = {}
172+
inputs = [data, axes]
173+
174+
if kwargs:
175+
logger.warning("kwargs contains un-used key")
176+
177+
new_attr = {}
178+
for key, val in attr.items():
179+
if val is not None:
180+
new_attr[key] = val
181+
attr = new_attr
182+
183+
for ind, val in enumerate(inputs):
184+
if val is None:
185+
inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX
186+
# remove tailing ""
187+
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
188+
inputs = inputs[:-1]
189+
190+
return self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
191+
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
192+
119193
def convert_to_input(self, tensor, const_name, is_optional=False, dtype=None):
120194
"""in ONNX, input shold come from node, so it must be a string"""
121195
if is_optional and tensor is None:

tf2onnx/onnx_opset/tensor.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -183,28 +183,25 @@ class Squeeze:
183183
def version_1(cls, ctx, node, **kwargs):
184184
# T output = Squeeze(T input, @list(int) squeeze_dims)
185185
# T squeezed = Squeeze(T data, @AttrType.INTS axes), axes are list of positive integers.
186-
axis = node.get_attr("axis")
187-
if not axis:
188-
axis = node.get_attr("squeeze_dims")
189-
if axis:
190-
del node.attr["squeeze_dims"]
186+
axes = node.get_attr_value("squeeze_dims")
187+
if axes is None:
188+
axes = []
191189
else:
192-
del node.attr["axis"]
190+
del node.attr["squeeze_dims"]
193191

194-
if axis and axis.ints:
195-
axis = axis.ints
196-
neg_axis = any([val < 0 for val in axis])
197-
if neg_axis:
192+
# TF uses empty axes to indicate that all 1 dims should be squeezed
193+
if len(axes) > 0:
194+
neg_axis = any([val < 0 for val in axes])
195+
if neg_axis and ctx.opset < 11:
198196
shape = ctx.get_shape(node.input[0])
199-
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
197+
utils.make_sure(shape is not None, "squeeze with negative axes and unknown rank requires opset >= 11")
200198
shape_len = len(shape)
201-
axis = [a + shape_len if a < 0 else a for a in axis]
202-
else:
203-
shape = ctx.get_shape(node.input[0])
204-
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
205-
axis = [i for i, j in enumerate(shape) if j == 1]
206-
if not axis: axis = [0]
207-
node.set_attr("axes", axis)
199+
axes = [a + shape_len if a < 0 else a for a in axes]
200+
if ctx.opset < 13:
201+
node.set_attr("axes", axes)
202+
else:
203+
axes_const = ctx.make_const(utils.make_name("axes_const"), np.array(axes, dtype=np.int64))
204+
ctx.replace_inputs(node, [node.input[0], axes_const.output[0]])
208205

209206
@classmethod
210207
def version_11(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)