13
13
from distutils .version import LooseVersion
14
14
15
15
import numpy as np
16
- import six
17
16
import tensorflow as tf
18
17
19
18
from tensorflow .core .framework import types_pb2 , tensor_pb2
@@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor):
70
69
"""Get data from tensor."""
71
70
make_sure (isinstance (tensor , tensor_pb2 .TensorProto ), "Require TensorProto" )
72
71
np_data = tensor_util .MakeNdarray (tensor )
73
- make_sure (isinstance (np_data , np .ndarray ), "{} isn't ndarray" . format ( np_data ) )
72
+ make_sure (isinstance (np_data , np .ndarray ), "%r isn't ndarray" , np_data )
74
73
return np_data
75
74
76
75
@@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True):
83
82
when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
84
83
when as_list=True, return 1, type is <class 'int'>.
85
84
"""
86
- make_sure (is_tf_const_op (op ), "{} isn't a const op" . format ( op .name ) )
85
+ make_sure (is_tf_const_op (op ), "%r isn't a const op" , op .name )
87
86
value = get_tf_tensor_data (op .get_attr ("value" ))
88
87
if as_list :
89
88
value = value .tolist ()
@@ -119,9 +118,6 @@ def map_tf_dtype(dtype):
119
118
120
119
def get_tf_node_attr (node , name ):
121
120
"""Parser TF node attribute."""
122
- if six .PY2 :
123
- # For python2, TF get_attr does not accept unicode
124
- name = str (name )
125
121
return node .get_attr (name )
126
122
127
123
@@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override):
136
132
"""
137
133
138
134
# ignore the following attributes
139
- ignored_attr = [ "unknown_rank" , "_class" , "Tshape" , "use_cudnn_on_gpu" , "Index" , "Tpaddings" ,
135
+ ignored_attr = { "unknown_rank" , "_class" , "Tshape" , "use_cudnn_on_gpu" , "Index" , "Tpaddings" ,
140
136
"TI" , "Tparams" , "Tindices" , "Tlen" , "Tdim" , "Tin" , "dynamic_size" , "Tmultiples" ,
141
137
"Tblock_shape" , "Tcrops" , "index_type" , "Taxis" , "U" , "maxval" ,
142
138
"Tout" , "Tlabels" , "Tindex" , "element_shape" , "Targmax" , "Tperm" , "Tcond" ,
143
139
"T_threshold" , "element_dtype" , "shape_type" , "_lower_using_switch_merge" ,
144
140
"parallel_iterations" , "_num_original_outputs" , "output_types" , "output_shapes" ,
145
141
"key_dtype" , "value_dtype" , "Tin" , "Tout" , "capacity" , "component_types" , "shapes" ,
146
- "Toutput_types" ]
142
+ "Toutput_types" }
147
143
148
144
node_list = g .get_operations ()
149
145
functions = {}
@@ -176,12 +172,11 @@ def tflist_to_onnx(g, shape_override):
176
172
attr_cnt [a ] += 1
177
173
if a == "dtype" :
178
174
attr [a ] = map_tf_dtype (get_tf_node_attr (node , "dtype" ))
179
- elif a in [ "T" ] :
175
+ elif a == "T" :
180
176
dtype = get_tf_node_attr (node , a )
181
- if dtype :
182
- if not isinstance (dtype , list ):
183
- dtypes [node .name ] = map_tf_dtype (dtype )
184
- elif a in ["output_type" , "output_dtype" , "out_type" , "Tidx" , "out_idx" ]:
177
+ if dtype and not isinstance (dtype , list ):
178
+ dtypes [node .name ] = map_tf_dtype (dtype )
179
+ elif a in {"output_type" , "output_dtype" , "out_type" , "Tidx" , "out_idx" }:
185
180
# Tidx is used by Range
186
181
# out_idx is used by ListDiff
187
182
attr [a ] = map_tf_dtype (get_tf_node_attr (node , a ))
@@ -192,7 +187,7 @@ def tflist_to_onnx(g, shape_override):
192
187
elif a == "output_shapes" :
193
188
# we should not need it since we pull the shapes above already
194
189
pass
195
- elif a in [ "body" , "cond" , "then_branch" , "else_branch" ] :
190
+ elif a in { "body" , "cond" , "then_branch" , "else_branch" } :
196
191
input_shapes = [inp .get_shape () for inp in node .inputs ]
197
192
nattr = get_tf_node_attr (node , a )
198
193
attr [a ] = nattr .name
0 commit comments