diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index b45f07c63..2adf7acea 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -91,7 +91,7 @@ def _run_backend(self, g, outputs, input_dict): raise ValueError("unknown backend") return y - def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=0., + def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5, convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False, check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None): # optional - passed to process_tf_graph diff --git a/tests/test_backend.py b/tests/test_backend.py index 4e9b5de91..6bf9761cd 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1118,6 +1118,7 @@ def test_sign(self): self._run_test_case([_OUTPUT], {_INPUT: x_val}) tf.reset_default_graph() + @check_target("rs6", "onehot") def test_onehot0(self): x_val = np.array([0, 1, 2], dtype=np.int32) depth = 5 @@ -1138,6 +1139,7 @@ def test_onehot1(self): _ = tf.identity(x_, name=_TFOUTPUT) self._run_test_case([_OUTPUT], {_INPUT: x_val}) + @check_target("rs6", "onehot") def test_onehot2(self): for axis in [-1, 0, 1]: tf.reset_default_graph() @@ -1148,6 +1150,7 @@ def test_onehot2(self): _ = tf.identity(x_, name=_TFOUTPUT) self._run_test_case([_OUTPUT], {_INPUT: x_val}) + @check_target("rs6", "onehot") @check_opset_min_version(9, "onehot") def test_onehot3(self): # rank 1 diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 15f218cd3..b10deee61 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -45,7 +45,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, raise ValueError("only onnxruntime is supported to test transpose optimizer") for expected_val, actual_val in zip(expected, actual): - self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.) + self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=1e-5) self.assertEqual(expected_val.dtype, actual_val.dtype) self.assertEqual(expected_val.shape, actual_val.shape) @@ -147,6 +147,21 @@ def test_transpose_with_shape(self): self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)}, model_proto, remaining_transpose_num=0) + def test_transpose_with_identity(self): + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans") + node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity") + + graph = helper.make_graph( + [node1, node2], + "transpose_with_identity", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))], + [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))], + ) + + model_proto = helper.make_model(graph, producer_name="onnx-tests") + self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)}, + model_proto, remaining_transpose_num=1) + # Tranpose Optimizer Tests End # Identity Optimizer Tests Start diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index fd5ff2b55..4689f8f07 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -340,8 +340,13 @@ def _transpose_handler(self, trans, node): ops = self._g.get_nodes() self._g.replace_all_inputs(ops, node.output[0], trans.input[0]) + shape = self._g.get_shape(node.output[0]) + dtype = self._g.get_dtype(node.output[0]) self._g.remove_node(trans.name) self._g.remove_node(node.name) + if node.output[0] in self._g.outputs: + self._g.make_node("Identity", [trans.input[0]], + outputs=node.output, shapes=[shape], dtypes=[dtype]) return True return False @@ -380,7 +385,8 @@ def _mul_handler(self, trans, node): multiplier_input_id = i multiplier_input_node = input_node - if not multiplier_input_node.is_const(): + # node's inputs may come from one same node. if so the multiplier_input_node may be none + if multiplier_input_node is None or not multiplier_input_node.is_const(): return False multiplier = multiplier_input_node.get_tensor_value(as_list=False) @@ -408,6 +414,8 @@ def _mul_handler(self, trans, node): return False def _identity_handler(self, trans, node): + if node.output[0] in self._g.outputs: + return False ops = self._g.get_nodes() self._g.replace_all_inputs(ops, node.output[0], trans.output[0]) self._g.remove_node(node.name)