Skip to content

Commit 0c5f97d

Browse files
authored
Merge pull request #1054 from xadupre/perf2
Perf gain in tf_utils.py, more efficient error messages, faster comparisons
2 parents 0110037 + 75fe90d commit 0c5f97d

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

tf2onnx/tf_utils.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from distutils.version import LooseVersion
1414

1515
import numpy as np
16-
import six
1716
import tensorflow as tf
1817

1918
from tensorflow.core.framework import types_pb2, tensor_pb2
@@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor):
7069
"""Get data from tensor."""
7170
make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
7271
np_data = tensor_util.MakeNdarray(tensor)
73-
make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
72+
make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)
7473
return np_data
7574

7675

@@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True):
8382
when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
8483
when as_list=True, return 1, type is <class 'int'>.
8584
"""
86-
make_sure(is_tf_const_op(op), "{} isn't a const op".format(op.name))
85+
make_sure(is_tf_const_op(op), "%r isn't a const op", op.name)
8786
value = get_tf_tensor_data(op.get_attr("value"))
8887
if as_list:
8988
value = value.tolist()
@@ -119,9 +118,6 @@ def map_tf_dtype(dtype):
119118

120119
def get_tf_node_attr(node, name):
121120
"""Parser TF node attribute."""
122-
if six.PY2:
123-
# For python2, TF get_attr does not accept unicode
124-
name = str(name)
125121
return node.get_attr(name)
126122

127123

@@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override):
136132
"""
137133

138134
# ignore the following attributes
139-
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
135+
ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
140136
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
141137
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
142138
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
143139
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
144140
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
145141
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
146-
"Toutput_types"]
142+
"Toutput_types"}
147143

148144
node_list = g.get_operations()
149145
functions = {}
@@ -176,12 +172,11 @@ def tflist_to_onnx(g, shape_override):
176172
attr_cnt[a] += 1
177173
if a == "dtype":
178174
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
179-
elif a in ["T"]:
175+
elif a == "T":
180176
dtype = get_tf_node_attr(node, a)
181-
if dtype:
182-
if not isinstance(dtype, list):
183-
dtypes[node.name] = map_tf_dtype(dtype)
184-
elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
177+
if dtype and not isinstance(dtype, list):
178+
dtypes[node.name] = map_tf_dtype(dtype)
179+
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}:
185180
# Tidx is used by Range
186181
# out_idx is used by ListDiff
187182
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
@@ -192,7 +187,7 @@ def tflist_to_onnx(g, shape_override):
192187
elif a == "output_shapes":
193188
# we should not need it since we pull the shapes above already
194189
pass
195-
elif a in ["body", "cond", "then_branch", "else_branch"]:
190+
elif a in {"body", "cond", "then_branch", "else_branch"}:
196191
input_shapes = [inp.get_shape() for inp in node.inputs]
197192
nattr = get_tf_node_attr(node, a)
198193
attr[a] = nattr.name

0 commit comments

Comments
 (0)