Skip to content

Commit 4bf83c8

Browse files
bedapislBedrich
authored andcommitted
Add cast to same type before equal operator
1 parent 6ec695b commit 4bf83c8

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

tests/test_equal.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed under the MIT license.
2+
3+
"""Unit tests for equal"""
4+
5+
from __future__ import absolute_import
6+
from __future__ import division
7+
from __future__ import print_function
8+
9+
import numpy as np
10+
import tensorflow as tf
11+
12+
from backend_test_base import Tf2OnnxBackendTestBase
13+
14+
15+
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
16+
# pylint: disable=abstract-method,arguments-differ
17+
18+
class EqualTests(Tf2OnnxBackendTestBase):
19+
20+
def test_equal_with_different_parameters(self):
21+
input_val = np.array([5], dtype=np.int32)
22+
23+
def func(input_val):
24+
tensor = tf.zeros(input_val)
25+
input_size = tf.size(tensor)
26+
constant = tf.constant(3, dtype=tf.int32)
27+
return tf.math.equal(input_size, constant, name="output")
28+
29+
feed_dict = {"input:0": input_val}
30+
input_names_with_port = ["input:0"]
31+
output_names_with_port = ["output:0"]
32+
33+
current_opset = self.config.opset
34+
self.config.opset = 12
35+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
36+
self.config.opset = current_opset
37+
38+
if __name__ == '__main__':
39+
unittest_main()

tf2onnx/onnx_opset/logical.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
3434
graph.set_dtype(inp_cast.output[0], target_dtype)
3535

3636

37+
def _add_cast_to_same_type_to_inputs(graph, node):
38+
common_dtype = graph.get_dtype(node.input[0])
39+
40+
for inp in node.input[1:]:
41+
if graph.get_dtype(inp) != common_dtype:
42+
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
43+
graph.copy_shape(inp, inp_cast.output[0])
44+
graph.set_dtype(inp_cast.output[0], common_dtype)
45+
46+
3747
@tf_op("LogicalNot", onnx_op="Not")
3848
class DirectOp:
3949
@classmethod
@@ -81,7 +91,8 @@ def version_7(cls, ctx, node, **kwargs):
8191

8292
@classmethod
8393
def version_11(cls, ctx, node, **kwargs):
84-
# starting with opset-11, equal supports all types
94+
# starting with opset-11, equal supports all types (but both operands must be of the same type)
95+
_add_cast_to_same_type_to_inputs(ctx, node)
8596
need_not = node.type == "NotEqual"
8697
if need_not:
8798
node.type = "Equal"

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _wrap_concat_with_cast(ctx, node):
6161
class Size:
6262
@classmethod
6363
def version_1(cls, ctx, node, **kwargs):
64-
pass
64+
ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
6565

6666

6767
@tf_op("Flatten")

0 commit comments

Comments
 (0)