From 528c97f41bbda082bf728d37440bd88a85c1dd6e Mon Sep 17 00:00:00 2001 From: hwangdeyu Date: Wed, 22 Dec 2021 07:33:48 +0000 Subject: [PATCH] add unit32 unit64 type support Signed-off-by: hwangdeyu --- tests/test_backend.py | 6 ++++++ tf2onnx/tf_utils.py | 2 ++ tf2onnx/utils.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index 9e00f9e56..a5f3cf3cf 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2353,6 +2353,12 @@ def func(x): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2)) + def func(x): + x_ = tf.cast(x, tf.uint64) + return tf.identity(x_, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(7, "sign") def test_sign(self): x_vals = [np.array([1.0, 2.0, 0.0, -1.0, 0.0, -2.0], dtype=np.float32).reshape((2, 3)), diff --git a/tf2onnx/tf_utils.py b/tf2onnx/tf_utils.py index 4dd6eef37..849fc98c4 100644 --- a/tf2onnx/tf_utils.py +++ b/tf2onnx/tf_utils.py @@ -34,6 +34,8 @@ types_pb2.DT_INT8: onnx_pb.TensorProto.INT8, types_pb2.DT_UINT8: onnx_pb.TensorProto.UINT8, types_pb2.DT_UINT16: onnx_pb.TensorProto.UINT16, + types_pb2.DT_UINT32: onnx_pb.TensorProto.UINT32, + types_pb2.DT_UINT64: onnx_pb.TensorProto.UINT64, types_pb2.DT_INT64: onnx_pb.TensorProto.INT64, types_pb2.DT_STRING: onnx_pb.TensorProto.STRING, types_pb2.DT_COMPLEX64: onnx_pb.TensorProto.COMPLEX64, diff --git a/tf2onnx/utils.py b/tf2onnx/utils.py index adb763383..4d5835cd4 100644 --- a/tf2onnx/utils.py +++ b/tf2onnx/utils.py @@ -38,6 +38,8 @@ onnx_pb.TensorProto.INT8: np.int8, onnx_pb.TensorProto.UINT8: np.uint8, onnx_pb.TensorProto.UINT16: np.uint16, + onnx_pb.TensorProto.UINT32: np.uint32, + onnx_pb.TensorProto.UINT64: np.uint64, onnx_pb.TensorProto.INT64: np.int64, onnx_pb.TensorProto.UINT64: np.uint64, onnx_pb.TensorProto.BOOL: np.bool, @@ -58,6 +60,8 @@ onnx_pb.TensorProto.INT8: "int8", onnx_pb.TensorProto.UINT8: "uint8", onnx_pb.TensorProto.UINT16: "uint16", + onnx_pb.TensorProto.UINT32: "uint32", + onnx_pb.TensorProto.UINT64: "uint64", onnx_pb.TensorProto.INT64: "int64", onnx_pb.TensorProto.STRING: "string", onnx_pb.TensorProto.BOOL: "bool",