33 changes: 32 additions & 1 deletion tests/test_backend.py
@@ -1044,12 +1044,31 @@ def func(x1, x2):
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

def test_sequeeze_no_axis_specified(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2, 1, 1))
def func(x):
x_ = tf.squeeze(x)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

def test_sequeeze_no_axis(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))
def func(x):
x_ = tf.squeeze(x)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Pad")
def test_sequeeze_no_axis_specified_unknown_rank(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
z_val = np.zeros((1, 2), dtype=np.int64)
def func(x, y, z):
y_ = tf.pad(y, z)
x_ = tf.reshape(x, y_)
x_ = tf.squeeze(x_)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})

def test_sequeeze_positive_axis(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
def func(x):
@@ -1071,6 +1090,18 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Squeeze")
def test_sequeeze_mixed_axis_unknown_rank(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
z_val = np.zeros((1, 2), dtype=np.int64)
def func(x, y, z):
y_ = tf.pad(y, z)
x_ = tf.reshape(x, y_)
x_ = tf.squeeze(x_, [1, -1])
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
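
These unknown-rank cases route the target shape through tf.pad and tf.reshape so the converter cannot infer the rank of the tensor being squeezed. A standalone sketch (not part of this PR) of the squeeze semantics the new tests exercise:

import numpy as np
import tensorflow as tf

x = tf.constant(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2, 1, 1)))
print(tf.squeeze(x).shape)           # (2, 2): with no axes given, every size-1 dim is dropped
print(tf.squeeze(x, [1, -1]).shape)  # (2, 2, 1): only the requested size-1 axes are dropped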

def test_transpose(self):
x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
def func(x):
127 changes: 119 additions & 8 deletions tf2onnx/graph_builder.py
@@ -36,21 +36,22 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
if self.graph.opset < 10:
# "data" is string
# "starts", "ends" and "axes" are attributes, and "axes" is optional.
inputs = [kwargs.pop("data")]
data = kwargs.pop("data")
starts = self.convert_to_attribute(kwargs.pop("starts"))
ends = self.convert_to_attribute(kwargs.pop("ends"))
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
attr = {"starts": starts, "ends": ends, "axes": axes}
inputs = [data]
else:
# slice-10 has 3 required inputs "data", "starts", "ends"
# and 2 optional inputs "axes", "steps"
# input sequence should be "data", "starts", "ends", "axes", "steps"
attr = {}
data = self.convert_to_input(kwargs.pop("data"))
starts = self.convert_to_input(kwargs.pop("starts"), dtype=np.int64)
ends = self.convert_to_input(kwargs.pop("ends"), dtype=np.int64)
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True, dtype=np.int64)
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True, dtype=np.int64)
data = kwargs.pop("data")
starts = self.convert_to_input(kwargs.pop("starts"), "const_starts", dtype=np.int64)
ends = self.convert_to_input(kwargs.pop("ends"), "const_ends", dtype=np.int64)
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
steps = self.convert_to_input(kwargs.pop("steps", None), "const_steps", is_optional=True, dtype=np.int64)
inputs = [data, starts, ends, axes, steps]

# pre-process inputs and attr
@@ -79,7 +80,117 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]

def convert_to_input(self, tensor, is_optional=False, dtype=None):
def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
"""
ReduceSum changes its schema at opset 13: it treats axes as a dynamic input
kwargs: key could be ["data", "axes", "keepdims", "noop_with_empty_axes", "outputs"].
"""
outputs = kwargs.pop("outputs", None)

if self.graph.opset < 13:
data = kwargs.pop("data")
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
keepdims = kwargs.pop("keepdims", 1)
noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
if noop_with_empty_axes == 0 and axes == []:
axes = None
attr = {"axes": axes, "keepdims": keepdims}
inputs = [data]
else:
keepdims = kwargs.pop("keepdims", 1)
noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
data = self.convert_to_input(kwargs.pop("data"), "const_data")
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
attr = {"keepdims": keepdims, "noop_with_empty_axes": noop_with_empty_axes}
inputs = [data, axes]

if kwargs:
logger.warning("kwargs contains un-used key")

new_attr = {}
for key, val in attr.items():
if val is not None:
new_attr[key] = val
attr = new_attr

return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
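
For reference, a minimal sketch (not part of this change) of the two ReduceSum forms this helper has to emit, written directly with onnx.helper:

from onnx import helper

# opset < 13: axes is an attribute on the node
reduce_12 = helper.make_node("ReduceSum", inputs=["data"], outputs=["out"],
                             axes=[-1], keepdims=0)

# opset >= 13: axes is an optional second input (an int64 tensor), and
# noop_with_empty_axes controls what an empty axes tensor means
reduce_13 = helper.make_node("ReduceSum", inputs=["data", "axes"], outputs=["out"],
                             keepdims=0, noop_with_empty_axes=1)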

def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None):
"""
Squeeze changes its schema at opset 13: it treats axes as a dynamic input
kwargs: key could be ["data", "axes"].
"""
outputs = kwargs.pop("outputs", None)

if self.graph.opset < 13:
data = kwargs.pop("data")
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
attr = {"axes": axes}
inputs = [data]
else:
data = kwargs.pop("data")
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
attr = {}
inputs = [data, axes]

if kwargs:
logger.warning("kwargs contains un-used key")

new_attr = {}
for key, val in attr.items():
if val is not None:
new_attr[key] = val
attr = new_attr

for ind, val in enumerate(inputs):
if val is None:
inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX
# remove trailing ""
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
inputs = inputs[:-1]

return self.graph.make_node(op_type="Squeeze", inputs=inputs, attr=attr, name=name,
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]

def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None):
"""
Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
kwargs: key could be ["data", "axes"].
"""
outputs = kwargs.pop("outputs", None)

if self.graph.opset < 13:
data = kwargs.pop("data")
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
attr = {"axes": axes}
inputs = [data]
else:
data = kwargs.pop("data")
axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
attr = {}
inputs = [data, axes]

if kwargs:
logger.warning("kwargs contains un-used key")

new_attr = {}
for key, val in attr.items():
if val is not None:
new_attr[key] = val
attr = new_attr

for ind, val in enumerate(inputs):
if val is None:
inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX
# remove trailing ""
while inputs[-1] == utils.ONNX_EMPTY_INPUT:
inputs = inputs[:-1]

return self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
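
A hypothetical call site for these two helpers (gb, ctx and node are placeholders, not names taken from this diff):

gb = GraphBuilder(ctx)

# Squeeze gets an "axes" attribute below opset 13 and an int64 Const input from 13 on;
# omitting "axes" leaves ONNX to squeeze every size-1 dimension.
squeezed = gb.make_squeeze({"data": node.input[0], "axes": [1, -1]})

# Unsqueeze follows the same attribute-vs-input switch.
expanded = gb.make_unsqueeze({"data": squeezed, "axes": [0]})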

def convert_to_input(self, tensor, const_name, is_optional=False, dtype=None):
"""in ONNX, input shold come from node, so it must be a string"""
if is_optional and tensor is None:
return None
@@ -88,7 +199,7 @@ def convert_to_input(self, tensor, is_optional=False, dtype=None):

res = tensor
if isinstance(tensor, list):
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor, dtype)).output[0]
res = self.graph.make_const(utils.make_name(const_name), np.array(tensor, dtype)).output[0]

utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")
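
With the new const_name argument, literal lists are folded into a named Const node while existing tensor names pass through. A behavior sketch, assumed from the visible code rather than copied from a test:

gb = GraphBuilder(ctx)

axes = gb.convert_to_input([1, -1], "const_axes", dtype=np.int64)
# -> output name of a freshly created int64 Const node

axes = gb.convert_to_input("some_tensor:0", "const_axes", dtype=np.int64)
# -> returned unchanged; it already names a tensor in the graph

axes = gb.convert_to_input(None, "const_axes", is_optional=True, dtype=np.int64)
# -> None; make_squeeze/make_unsqueeze later turn this into the empty ONNX input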

16 changes: 10 additions & 6 deletions tf2onnx/onnx_opset/nn.py
@@ -1141,10 +1141,11 @@ def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
# implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
reduce_sum_output = GraphBuilder(ctx).make_reduce_sum(
{"data": mul1.output[0], "axes": [-1], "keepdims": 1, "noop_with_empty_axes": 1})
const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum_output])
shapes = tf_ori_node.output_shapes
dtypes = tf_ori_node.output_dtypes
ctx.remove_node(tf_ori_node.name)
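
The Mul -> ReduceSum -> Mul chain above assembles the usual decomposition loss = -sum(label * log_softmax(logit), axis=-1); a NumPy sketch (not part of the diff) of the same computation:

import numpy as np

def softmax_cross_entropy(label, logit):
    # plain log-softmax followed by the weighted sum over the class axis
    log_softmax = logit - np.log(np.sum(np.exp(logit), axis=-1, keepdims=True))
    return -np.sum(label * log_softmax, axis=-1)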
@@ -1223,9 +1224,11 @@ def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
# logit_exp=exp(logit) >> sum = tf.reduce_sum(logit_exp, axis = -1), masked_sum = reduce_sum(mul(logit_exp, mul))
# >> -log(masked_sum/sum)
logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
logit_exp_sum = GraphBuilder(ctx).make_reduce_sum(
{"data": logit_exp, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
masked_sum = GraphBuilder(ctx).make_reduce_sum(
{"data": masked, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
Expand Down Expand Up @@ -1266,10 +1269,11 @@ def version_7(cls, ctx, node, **kwargs):
log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=[logit_name])
# implement tf.multiply(np.float32(-1.0), tf.reduce_sum(tf.multiply(one_hot, log_softmax), axis=1))
mul1 = ctx.make_node(op_type="Mul", inputs=[onehot.output[0], log_softmax.output[0]])
reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [1]})
reduce_sum_output = GraphBuilder(ctx).make_reduce_sum(
{"data": mul1.output[0], "axes": [1], "keepdims": 1, "noop_with_empty_axes": 1})
const_name = utils.make_name("const_negative_one")
const_negative_one = ctx.make_const(name=const_name, np_val=np.array(-1).astype(dtype))
mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum_output])

shapes = node.output_shapes
dtypes = node.output_dtypes
47 changes: 32 additions & 15 deletions tf2onnx/onnx_opset/reduction.py
@@ -45,16 +45,26 @@ def version_1(cls, ctx, node, **kwargs):

node.set_attr("axes", axes)
ctx.remove_input(node, node.input[1], 1)
keep_dims = node.get_attr("keep_dims")
if keep_dims:
del node.attr['keep_dims']
node.set_attr("keepdims", keep_dims.i)
keep_dims = node.get_attr_value("keep_dims", 0)
node.set_attr("keepdims", keep_dims)
del node.attr['keep_dims']

@classmethod
def version_11(cls, ctx, node, **kwargs):
# Opset 11 supports negative axis, but core logic is same
cls.version_1(ctx, node, **kwargs)

@classmethod
def version_13(cls, ctx, node, **kwargs):
if node.type == "ReduceSum":
keep_dims = node.get_attr_value("keep_dims", 0)
node.set_attr("keepdims", keep_dims)
del node.attr['keep_dims']
node.set_attr("noop_with_empty_axes", 1)
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
else:
cls.version_11(ctx, node, **kwargs)
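
The keepdims rename and the int64 Cast are mechanical; noop_with_empty_axes=1 is what preserves TF semantics when the axes tensor is only known at runtime, because TF treats an empty reduction-indices list as "reduce nothing" while ReduceSum-13 defaults to "reduce everything". A short sketch of that difference (rationale as read from this change):

import numpy as np
import tensorflow as tf

x = tf.constant(np.ones((2, 3), dtype=np.float32))
print(tf.reduce_sum(x, axis=[]).shape)  # (2, 3): empty axes reduce nothing in TF
print(tf.reduce_sum(x).shape)           # (): omitting axes reduces everything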

@tf_op(["ArgMax", "ArgMin"])
class ArgMax:
@@ -114,15 +124,20 @@ def version_6(cls, ctx, node, **kwargs):
cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
op_type = "ReduceMin" if node.type == "All" else "ReduceSum"
reduce_node = ctx.make_node(op_type=op_type, inputs=cast.output,
attr={"axes": reduce_dim, "keepdims": keepdims})

if op_type == "ReduceSum":
reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
{"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims, "noop_with_empty_axes": 1})
else:
reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
ctx.make_node(op_type="Greater", inputs=[reduce_node.output[0], zero_node.output[0]],
ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)


@@ -198,17 +213,19 @@ def version_9(cls, ctx, node, **kwargs):
one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
attr={'axis': 0})
if node.type == "SegmentMean":
scaling_node = ctx.make_node("ReduceSum", [one_hot_node.output[0]], attr={'axes': [1], 'keepdims': 0})
scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
{"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
elif node.type == "SegmentSqrtN":
seg_cnts_node = ctx.make_node("ReduceSum", [one_hot_node.output[0]], attr={'axes': [1], 'keepdims': 0})
scaling_node = ctx.make_node("Sqrt", [seg_cnts_node.output[0]])
seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
{"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
else:
scaling_node = None
scaling_node_output = None

if scaling_node and num_segments_specified:
if scaling_node_output is not None and num_segments_specified:
# If empty segments are possible, we must avoid division by zero
const_one_float = ctx.make_const(utils.make_name("const_one_float"), np.array(1, dtype=np.float32))
scaling_node = ctx.make_node("Max", [scaling_node.output[0], const_one_float.output[0]])
scaling_node_output = ctx.make_node("Max", [scaling_node_output, const_one_float.output[0]]).output[0]


if onnx_op == "ReduceSum":
@@ -226,8 +243,8 @@ def version_9(cls, ctx, node, **kwargs):

# Shapes [s, n] * [n, P] => [s, P]
product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]], op_name_scope=node.name)
if scaling_node is not None:
scaling_node_unsqueeze = ctx.make_node("Unsqueeze", [scaling_node.output[0]], attr={'axes': [1]})
if scaling_node_output is not None:
scaling_node_unsqueeze = ctx.make_node("Unsqueeze", [scaling_node_output], attr={'axes': [1]})
product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

# Create new shape [0, a, b, ..., c]
33 changes: 15 additions & 18 deletions tf2onnx/onnx_opset/tensor.py
@@ -183,28 +183,25 @@ class Squeeze:
def version_1(cls, ctx, node, **kwargs):
# T output = Squeeze(T input, @list(int) squeeze_dims)
# T squeezed = Squeeze(T data, @AttrType.INTS axes), axes are list of positive integers.
axis = node.get_attr("axis")
if not axis:
axis = node.get_attr("squeeze_dims")
if axis:
del node.attr["squeeze_dims"]
axes = node.get_attr_value("squeeze_dims")
if axes is None:
axes = []
else:
del node.attr["axis"]
del node.attr["squeeze_dims"]

if axis and axis.ints:
axis = axis.ints
neg_axis = any([val < 0 for val in axis])
if neg_axis:
# TF uses empty axes to indicate that all 1 dims should be squeezed
if len(axes) > 0:
neg_axis = any([val < 0 for val in axes])
if neg_axis and ctx.opset < 11:
shape = ctx.get_shape(node.input[0])
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
utils.make_sure(shape is not None, "squeeze with negative axes and unknown rank requires opset >= 11")
shape_len = len(shape)
axis = [a + shape_len if a < 0 else a for a in axis]
else:
shape = ctx.get_shape(node.input[0])
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
axis = [i for i, j in enumerate(shape) if j == 1]
if not axis: axis = [0]
node.set_attr("axes", axis)
axes = [a + shape_len if a < 0 else a for a in axes]
if ctx.opset < 13:
node.set_attr("axes", axes)
else:
axes_const = ctx.make_const(utils.make_name("axes_const"), np.array(axes, dtype=np.int64))
ctx.replace_inputs(node, [node.input[0], axes_const.output[0]])
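
A sketch (not part of the diff) of the normalization version_1 applies when negative axes meet an opset below 11, plus the attribute-vs-input switch at opset 13:

axes = [1, -1]
rank = 5                                         # e.g. input shape (2, 1, 2, 1, 1)
axes = [a + rank if a < 0 else a for a in axes]  # -> [1, 4]; opsets >= 11 accept negatives as-is
# opset < 13: axes is stored as the "axes" attribute on the Squeeze node;
# opset >= 13: it becomes an int64 Const wired in as the second input,
# which is what replace_inputs does above.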

@classmethod
def version_11(cls, ctx, node, **kwargs):