2 changes: 1 addition & 1 deletion tests/backend_test_base.py
@@ -91,7 +91,7 @@ def _run_backend(self, g, outputs, input_dict):
raise ValueError("unknown backend")
return y

def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=0.,
def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False,
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None):
# optional - passed to process_tf_graph
3 changes: 3 additions & 0 deletions tests/test_backend.py
@@ -1118,6 +1118,7 @@ def test_sign(self):
self._run_test_case([_OUTPUT], {_INPUT: x_val})
tf.reset_default_graph()

@check_target("rs6", "onehot")
def test_onehot0(self):
x_val = np.array([0, 1, 2], dtype=np.int32)
depth = 5
@@ -1138,6 +1139,7 @@ def test_onehot1(self):
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val})

@check_target("rs6", "onehot")
def test_onehot2(self):
for axis in [-1, 0, 1]:
tf.reset_default_graph()
@@ -1148,6 +1150,7 @@ def test_onehot2(self):
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val})

@check_target("rs6", "onehot")
@check_opset_min_version(9, "onehot")
def test_onehot3(self):
# rank 1
17 changes: 16 additions & 1 deletion tests/test_optimizers.py
@@ -45,7 +45,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
raise ValueError("only onnxruntime is supported to test transpose optimizer")

for expected_val, actual_val in zip(expected, actual):
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=1e-5)
self.assertEqual(expected_val.dtype, actual_val.dtype)
self.assertEqual(expected_val.shape, actual_val.shape)

@@ -147,6 +147,21 @@ def test_transpose_with_shape(self):
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=0)

def test_transpose_with_identity(self):
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity")

graph = helper.make_graph(
[node1, node2],
"transpose_with_identity",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))],
)

model_proto = helper.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
model_proto, remaining_transpose_num=1)

# Transpose Optimizer Tests End

# Identity Optimizer Tests Start
10 changes: 9 additions & 1 deletion tf2onnx/optimizer/transpose_optimizer.py
@@ -340,8 +340,13 @@ def _transpose_handler(self, trans, node):
ops = self._g.get_nodes()
self._g.replace_all_inputs(ops, node.output[0], trans.input[0])

shape = self._g.get_shape(node.output[0])
dtype = self._g.get_dtype(node.output[0])
Collaborator:
please add tests to cover these cases.

Collaborator (Author):
1. The case where the transpose node's output is a graph output is already covered by test_transpose_leaky_relu.
2. The case where the identity node's output is a graph output is covered by the newly added test case "test_transpose_with_identity".

self._g.remove_node(trans.name)
self._g.remove_node(node.name)
if node.output[0] in self._g.outputs:
self._g.make_node("Identity", [trans.input[0]],
outputs=node.output, shapes=[shape], dtypes=[dtype])
return True
return False

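The branch added above handles a pair of cancelling transposes whose second node produces a graph output: the two nodes can no longer simply be dropped, because the graph output name has to survive, so an Identity node carrying the original shape and dtype is created in their place. A minimal sketch of that pattern, assuming the standard onnx.helper API; the node and tensor names here are illustrative, not taken from the PR:

from onnx import helper, TensorProto

# Two transposes that cancel each other out; the second one feeds the graph output "Z".
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
node2 = helper.make_node("Transpose", ["Y"], ["Z"], perm=[0, 3, 1, 2], name="trans_2")

graph = helper.make_graph(
    [node1, node2],
    "cancelling_transposes",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 3, 4, 5))],
)
model_proto = helper.make_model(graph, producer_name="onnx-tests")

After optimization both Transpose nodes disappear, and the handler emits an Identity from "X" to "Z" so the graph output keeps its name, shape, and dtype.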
@@ -380,7 +385,8 @@ def _mul_handler(self, trans, node):
multiplier_input_id = i
multiplier_input_node = input_node

if not multiplier_input_node.is_const():
# both of node's inputs may come from the same node; in that case multiplier_input_node stays None
if multiplier_input_node is None or not multiplier_input_node.is_const():
return False
multiplier = multiplier_input_node.get_tensor_value(as_list=False)

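The None check above covers the situation the comment describes: when both inputs of the Mul come from one and the same upstream node, the loop never identifies a separate multiplier, so multiplier_input_node stays None and the handler has to bail out rather than crash. A rough sketch of such a graph, again using onnx.helper with illustrative names:

from onnx import helper, TensorProto

# A squaring pattern: both Mul inputs are the same Transpose output,
# so there is no constant multiplier for the optimizer to fold.
trans = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
mul = helper.make_node("Mul", ["Y", "Y"], ["Z"], name="mul")

graph = helper.make_graph(
    [trans, mul],
    "mul_with_same_input",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))],
)
model_proto = helper.make_model(graph, producer_name="onnx-tests")

With the guard in place the handler returns False for this pattern instead of dereferencing None.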
@@ -408,6 +414,8 @@ def _mul_handler(self, trans, node):
return False

def _identity_handler(self, trans, node):
if node.output[0] in self._g.outputs:
return False
ops = self._g.get_nodes()
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
self._g.remove_node(node.name)
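This guard corresponds to the new test_transpose_with_identity above: when the Identity feeds the graph output "Z", it is left in place rather than folded away, which is why that test expects remaining_transpose_num=1.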