diff --git a/tests/test_backend.py b/tests/test_backend.py
index 9edbd939e..84982bfc8 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -1044,12 +1044,31 @@ def func(x1, x2):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

     def test_sequeeze_no_axis_specified(self):
-        x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
+        x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2, 1, 1))
+        def func(x):
+            x_ = tf.squeeze(x)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
+    def test_sequeeze_no_axis(self):
+        x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2))
         def func(x):
             x_ = tf.squeeze(x)
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

+    @check_opset_min_version(11, "Pad")
+    def test_sequeeze_no_axis_specified_unknown_rank(self):
+        x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
+        y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
+        z_val = np.zeros((1, 2), dtype=np.int64)
+        def func(x, y, z):
+            y_ = tf.pad(y, z)
+            x_ = tf.reshape(x, y_)
+            x_ = tf.squeeze(x_)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
+
     def test_sequeeze_positive_axis(self):
         x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 2, 1))
         def func(x):
@@ -1071,6 +1090,18 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

+    @check_opset_min_version(11, "Squeeze")
+    def test_sequeeze_mixed_axis_unknown_rank(self):
+        x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
+        y_val = np.array([2, 1, 2, 1, 1], dtype=np.int64)
+        z_val = np.zeros((1, 2), dtype=np.int64)
+        def func(x, y, z):
+            y_ = tf.pad(y, z)
+            x_ = tf.reshape(x, y_)
+            x_ = tf.squeeze(x_, [1, -1])
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
+
     def test_transpose(self):
         x_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).reshape((2, 3))
         def func(x):
diff --git a/tf2onnx/graph_builder.py b/tf2onnx/graph_builder.py
index 6b1f252c3..7fa845c94 100644
--- a/tf2onnx/graph_builder.py
+++ b/tf2onnx/graph_builder.py
@@ -36,21 +36,22 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
         if self.graph.opset < 10:
             # "data" is string
             # "starts", "ends" and "axes" are attributes, and "axes" is optional.
-            inputs = [kwargs.pop("data")]
+            data = kwargs.pop("data")
             starts = self.convert_to_attribute(kwargs.pop("starts"))
             ends = self.convert_to_attribute(kwargs.pop("ends"))
             axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
             attr = {"starts": starts, "ends": ends, "axes": axes}
+            inputs = [data]
         else:
             # slice-10 has 3 required inputs "data", "starts", "ends"l
             # and 2 optional inputs "axes", "steps"
             # input sequence should be "data", "starts", "ends", "axes", "steps"
             attr = {}
-            data = self.convert_to_input(kwargs.pop("data"))
-            starts = self.convert_to_input(kwargs.pop("starts"), dtype=np.int64)
-            ends = self.convert_to_input(kwargs.pop("ends"), dtype=np.int64)
-            axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True, dtype=np.int64)
-            steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True, dtype=np.int64)
+            data = kwargs.pop("data")
+            starts = self.convert_to_input(kwargs.pop("starts"), "const_starts", dtype=np.int64)
+            ends = self.convert_to_input(kwargs.pop("ends"), "const_ends", dtype=np.int64)
+            axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
+            steps = self.convert_to_input(kwargs.pop("steps", None), "const_steps", is_optional=True, dtype=np.int64)
             inputs = [data, starts, ends, axes, steps]

         # pro-process inputs and attr
@@ -79,7 +80,117 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
         return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
                                     outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]

-    def convert_to_input(self, tensor, is_optional=False, dtype=None):
+    def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
+        """
+        ReduceSum changes its schema at opset 13: it treats some axes as dynamic input
+        kwargs: key could be ["data", "axes", "keepdims", "noop_with_empty_axes", "outputs"].
+        """
+        outputs = kwargs.pop("outputs", None)
+
+        if self.graph.opset < 13:
+            data = kwargs.pop("data")
+            axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
+            keepdims = kwargs.pop("keepdims", 1)
+            noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
+            if noop_with_empty_axes == 0 and axes == []:
+                axes = None
+            attr = {"axes": axes, "keepdims": keepdims}
+            inputs = [data]
+        else:
+            keepdims = kwargs.pop("keepdims", 1)
+            noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
+            data = self.convert_to_input(kwargs.pop("data"), "const_data")
+            axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64)
+            attr = {"keepdims": keepdims, "noop_with_empty_axes": noop_with_empty_axes}
+            inputs = [data, axes]
+
+        if kwargs:
+            logger.warning("kwargs contains un-used key")
+
+        new_attr = {}
+        for key, val in attr.items():
+            if val is not None:
+                new_attr[key] = val
+        attr = new_attr
+
+        return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
+                                    outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
+
+    def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None):
+        """
+        Squeeze changes its schema at opset 13: it treats axes as a dynamic input
+        kwargs: key could be ["data", "axes"].
+ """ + outputs = kwargs.pop("outputs", None) + + if self.graph.opset < 13: + data = kwargs.pop("data") + axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True) + attr = {"axes": axes} + inputs = [data] + else: + data = kwargs.pop("data") + axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64) + attr = {} + inputs = [data, axes] + + if kwargs: + logger.warning("kwargs contains un-used key") + + new_attr = {} + for key, val in attr.items(): + if val is not None: + new_attr[key] = val + attr = new_attr + + for ind, val in enumerate(inputs): + if val is None: + inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX + # remove tailing "" + while inputs[-1] == utils.ONNX_EMPTY_INPUT: + inputs = inputs[:-1] + + return self.graph.make_node(op_type="Squeeze", inputs=inputs, attr=attr, name=name, + outputs=outputs, shapes=shapes, dtypes=dtypes).output[0] + + def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None): + """ + Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input + kwargs: key could be ["data", "axes"]. + """ + outputs = kwargs.pop("outputs", None) + + if self.graph.opset < 13: + data = kwargs.pop("data") + axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True) + attr = {"axes": axes} + inputs = [data] + else: + data = kwargs.pop("data") + axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes", is_optional=True, dtype=np.int64) + attr = {} + inputs = [data, axes] + + if kwargs: + logger.warning("kwargs contains un-used key") + + new_attr = {} + for key, val in attr.items(): + if val is not None: + new_attr[key] = val + attr = new_attr + + for ind, val in enumerate(inputs): + if val is None: + inputs[ind] = utils.ONNX_EMPTY_INPUT # empty string means no connection in ONNX + # remove tailing "" + while inputs[-1] == utils.ONNX_EMPTY_INPUT: + inputs = inputs[:-1] + + return self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name, + outputs=outputs, shapes=shapes, dtypes=dtypes).output[0] + + def convert_to_input(self, tensor, const_name, is_optional=False, dtype=None): """in ONNX, input shold come from node, so it must be a string""" if is_optional and tensor is None: return None @@ -88,7 +199,7 @@ def convert_to_input(self, tensor, is_optional=False, dtype=None): res = tensor if isinstance(tensor, list): - res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor, dtype)).output[0] + res = self.graph.make_const(utils.make_name(const_name), np.array(tensor, dtype)).output[0] utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed") diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 892a09eb6..d89e15349 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -1141,10 +1141,11 @@ def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node): log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output) # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1)) mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]]) - reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]}) + reduce_sum_output = GraphBuilder(ctx).make_reduce_sum( + {"data": mul1.output[0], "axes": [-1], "keepdims": 1, "noop_with_empty_axes": 1}) const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"), 
diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index 892a09eb6..d89e15349 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -1141,10 +1141,11 @@ def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
     log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
     # implement tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))
     mul1 = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
-    reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [-1]})
+    reduce_sum_output = GraphBuilder(ctx).make_reduce_sum(
+        {"data": mul1.output[0], "axes": [-1], "keepdims": 1, "noop_with_empty_axes": 1})
     const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
                                         np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
-    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+    mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum_output])
     shapes = tf_ori_node.output_shapes
     dtypes = tf_ori_node.output_dtypes
     ctx.remove_node(tf_ori_node.name)
@@ -1223,9 +1224,11 @@ def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_nod
    # logit_exp=exp(logit) >> sum = tf.reduce_sum(logit_exp, axis = -1), masked_sum = reduce_sum(mul(logit_exp, mul))
    # >> -log(masked_sum/sum)
    logit_exp = ctx.make_node(op_type="Exp", inputs=[logit]).output[0]
-    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp], attr={"axes": [-1], "keepdims": 0}).output[0]
+    logit_exp_sum = GraphBuilder(ctx).make_reduce_sum(
+        {"data": logit_exp, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
-    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked], attr={"axes": [-1], "keepdims": 0}).output[0]
+    masked_sum = GraphBuilder(ctx).make_reduce_sum(
+        {"data": masked, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
    const_negative_one = ctx.make_const(name=utils.make_name("const_negative_one"),
@@ -1266,10 +1269,11 @@ def version_7(cls, ctx, node, **kwargs):
         log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=[logit_name])
         # implement tf.multiply(np.float32(-1.0), tf.reduce_sum(tf.multiply(one_hot, log_softmax), axis=1))
         mul1 = ctx.make_node(op_type="Mul", inputs=[onehot.output[0], log_softmax.output[0]])
-        reduce_sum = ctx.make_node(op_type="ReduceSum", inputs=[mul1.output[0]], attr={"axes": [1]})
+        reduce_sum_output = GraphBuilder(ctx).make_reduce_sum(
+            {"data": mul1.output[0], "axes": [1], "keepdims": 1, "noop_with_empty_axes": 1})
         const_name = utils.make_name("const_negative_one")
         const_negative_one = ctx.make_const(name=const_name, np_val=np.array(-1).astype(dtype))
-        mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum.output[0]])
+        mul2 = ctx.make_node(op_type="Mul", inputs=[const_negative_one.output[0], reduce_sum_output])
         shapes = node.output_shapes
         dtypes = node.output_dtypes
diff --git a/tf2onnx/onnx_opset/reduction.py b/tf2onnx/onnx_opset/reduction.py
index 5853e0772..7e2e212b3 100644
--- a/tf2onnx/onnx_opset/reduction.py
+++ b/tf2onnx/onnx_opset/reduction.py
@@ -45,16 +45,26 @@ def version_1(cls, ctx, node, **kwargs):
         node.set_attr("axes", axes)
         ctx.remove_input(node, node.input[1], 1)
-        keep_dims = node.get_attr("keep_dims")
-        if keep_dims:
-            del node.attr['keep_dims']
-            node.set_attr("keepdims", keep_dims.i)
+        keep_dims = node.get_attr_value("keep_dims", 0)
+        node.set_attr("keepdims", keep_dims)
+        del node.attr['keep_dims']

     @classmethod
     def version_11(cls, ctx, node, **kwargs):
         # Opset 11 supports negative axis, but core logic is same
         cls.version_1(ctx, node, **kwargs)

+    @classmethod
+    def version_13(cls, ctx, node, **kwargs):
+        if node.type == "ReduceSum":
+            keep_dims = node.get_attr_value("keep_dims", 0)
+            node.set_attr("keepdims", keep_dims)
+            del node.attr['keep_dims']
+            node.set_attr("noop_with_empty_axes", 1)
+            if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
+                ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
+        else:
+            cls.version_11(ctx, node, **kwargs)
+

 @tf_op(["ArgMax", "ArgMin"])
 class ArgMax:
@@ -114,15 +124,20 @@ def version_6(cls, ctx, node, **kwargs):
         cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
         keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
         op_type = "ReduceMin" if node.type == "All" else "ReduceSum"
-        reduce_node = ctx.make_node(op_type=op_type, inputs=cast.output,
-                                    attr={"axes": reduce_dim, "keepdims": keepdims})
+
+        if op_type == "ReduceSum":
+            reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
+                {"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims, "noop_with_empty_axes": 1})
+        else:
+            reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
+                                               attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

         zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

         shapes = node.output_shapes
         dtypes = node.output_dtypes
         ctx.remove_node(node.name)
-        ctx.make_node(op_type="Greater", inputs=[reduce_node.output[0], zero_node.output[0]],
+        ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                       name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)

@@ -198,17 +213,19 @@ def version_9(cls, ctx, node, **kwargs):
         one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
                                      attr={'axis': 0})
         if node.type == "SegmentMean":
-            scaling_node = ctx.make_node("ReduceSum", [one_hot_node.output[0]], attr={'axes': [1], 'keepdims': 0})
+            scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
+                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
         elif node.type == "SegmentSqrtN":
-            seg_cnts_node = ctx.make_node("ReduceSum", [one_hot_node.output[0]], attr={'axes': [1], 'keepdims': 0})
-            scaling_node = ctx.make_node("Sqrt", [seg_cnts_node.output[0]])
+            seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
+                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
+            scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
         else:
-            scaling_node = None
+            scaling_node_output = None

-        if scaling_node and num_segments_specified:
+        if scaling_node_output is not None and num_segments_specified:
             # If empty segments are possible, we must avoid division by zero
             const_one_float = ctx.make_const(utils.make_name("const_one_float"),
                                              np.array(1, dtype=np.float32))
-            scaling_node = ctx.make_node("Max", [scaling_node.output[0], const_one_float.output[0]])
+            scaling_node_output = ctx.make_node("Max", [scaling_node_output, const_one_float.output[0]]).output[0]

         if onnx_op == "ReduceSum":
@@ -226,8 +243,8 @@ def version_9(cls, ctx, node, **kwargs):
             # Shapes [s, n] * [n, P] => [s, P]
             product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]],
                                     op_name_scope=node.name)
-        if scaling_node is not None:
-            scaling_node_unsqueeze = ctx.make_node("Unsqueeze", [scaling_node.output[0]], attr={'axes': [1]})
+        if scaling_node_output is not None:
+            scaling_node_unsqueeze = ctx.make_node("Unsqueeze", [scaling_node_output], attr={'axes': [1]})
             product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

         # Create new shape [0, a, b, ..., c]
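A short illustration of the semantics the new ReduceSum version_13 handler preserves (values below are made up): numpy and TF treat an empty reduction-axes list as "reduce nothing", while ReduceSum-13 defaults to reducing over all axes when its axes input is empty, which appears to be why noop_with_empty_axes is pinned to 1; the Cast covers axes tensors that are not already int64, since the opset-13 axes input must be int64.

    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    # Empty axes in numpy/TF semantics: nothing is reduced, the shape is preserved.
    assert np.sum(x, axis=()).shape == (2, 3)
    # ONNX ReduceSum-13 with an empty "axes" input and the default
    # noop_with_empty_axes=0 would instead reduce over all axes to a scalar,
    # so the converter sets noop_with_empty_axes=1 to keep the TF behaviour.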
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index ba1f771a1..cb52196ca 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -183,28 +183,25 @@ class Squeeze:
     def version_1(cls, ctx, node, **kwargs):
         # T output = Squeeze(T input, @list(int) squeeze_dims)
         # T squeezed = Squeeze(T data, @AttrType.INTS axes), axes are list of positive integers.
-        axis = node.get_attr("axis")
-        if not axis:
-            axis = node.get_attr("squeeze_dims")
-            if axis:
-                del node.attr["squeeze_dims"]
+        axes = node.get_attr_value("squeeze_dims")
+        if axes is None:
+            axes = []
         else:
-            del node.attr["axis"]
+            del node.attr["squeeze_dims"]

-        if axis and axis.ints:
-            axis = axis.ints
-            neg_axis = any([val < 0 for val in axis])
-            if neg_axis:
+        # TF uses empty axes to indicate that all 1 dims should be squeezed
+        if len(axes) > 0:
+            neg_axis = any([val < 0 for val in axes])
+            if neg_axis and ctx.opset < 11:
                 shape = ctx.get_shape(node.input[0])
-                utils.make_sure(shape is not None, "squeeze input shape cannot be None")
+                utils.make_sure(shape is not None, "squeeze with negative axes and unknown rank requires opset >= 11")
                 shape_len = len(shape)
-                axis = [a + shape_len if a < 0 else a for a in axis]
-        else:
-            shape = ctx.get_shape(node.input[0])
-            utils.make_sure(shape is not None, "squeeze input shape cannot be None")
-            axis = [i for i, j in enumerate(shape) if j == 1]
-            if not axis: axis = [0]
-        node.set_attr("axes", axis)
+                axes = [a + shape_len if a < 0 else a for a in axes]
+            if ctx.opset < 13:
+                node.set_attr("axes", axes)
+            else:
+                axes_const = ctx.make_const(utils.make_name("axes_const"), np.array(axes, dtype=np.int64))
+                ctx.replace_inputs(node, [node.input[0], axes_const.output[0]])

     @classmethod
     def version_11(cls, ctx, node, **kwargs):
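For reference, a standalone sketch of the axes handling the reworked Squeeze handler performs (the helper below is illustrative, not part of tf2onnx): negative axes are folded into the input rank only below opset 11, because Squeeze-11 accepts negative axes natively, and from opset 13 the resulting list is emitted as a constant int64 input rather than an attribute.

    import numpy as np

    def normalized_squeeze_axes(axes, rank, opset):
        # Mirrors the handler's normalization step (illustrative only).
        if any(a < 0 for a in axes) and opset < 11:
            axes = [a + rank if a < 0 else a for a in axes]
        return np.array(axes, dtype=np.int64)

    # tf.squeeze(x, [1, -1]) on a rank-5 tensor, converted at opset 10:
    assert list(normalized_squeeze_axes([1, -1], 5, 10)) == [1, 4]
    # At opset 11+ a negative axis is legal in ONNX and passes through unchanged:
    assert list(normalized_squeeze_axes([1, -1], 5, 13)) == [1, -1]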