Skip to content

Commit 60b6813

Browse files
committed
minor modifications
1 parent 3052221 commit 60b6813

File tree

6 files changed

+35
-20
lines changed

6 files changed

+35
-20
lines changed

tests/backend_test_base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
import tensorflow as tf
2222
from tensorflow.python.ops import variables as variables_lib
23+
import onnx
2324
from common import get_test_config
2425
from tf2onnx import utils
2526
from tf2onnx.tfonnx import process_tf_graph
@@ -63,12 +64,23 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
6364
def run_onnxruntime(self, model_path, inputs, output_names):
6465
"""Run test against onnxruntime backend."""
6566
import onnxruntime as rt
67+
try:
68+
from onnxruntime.capi.onnxruntime_pybind11_state import Fail
69+
except ImportError:
70+
Fail = RuntimeError
6671
opt = rt.SessionOptions()
6772
# in case of issues with the runtime, one can enable more logging
6873
# opt.log_severity_level = 0
6974
# opt.log_verbosity_level = 255
7075
# opt.enable_profiling = True
71-
m = rt.InferenceSession(model_path, opt)
76+
with open(model_path, 'rb') as f:
77+
onx = onnx.load(f)
78+
try:
79+
m = rt.InferenceSession(model_path, opt)
80+
except Fail as e:
81+
with open(model_path, 'rb') as f:
82+
onx = onnx.load(f)
83+
raise AssertionError("Unable to load model '{}'\n{}".format(model_path, onx)) from e
7284
results = m.run(output_names, inputs)
7385
return results
7486

tf2onnx/graph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def input(self):
5656

5757
@input.setter
5858
def input(self, val):
59+
# The setter can catch that all inputs are change
60+
# but it cannot catch that one input is changed.
61+
# That's method replace_input and replace_inputs must
62+
# be used to change inputs to let the graph instance
63+
# update its internal indices.
5964
self._input = copy.deepcopy(val)
6065

6166
@property
@@ -988,7 +993,7 @@ def _get_unvisited_child(g, node, not_visited):
988993
all_input = list(filter(lambda a: a != '', all_input))
989994
for inp in sorted(all_input):
990995
j = self.get_node_by_output(inp)
991-
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
996+
utils.make_sure(j is not None, "Cannot find node with output %r", inp)
992997
if self.parent_graph and j.name not in op_name_to_index:
993998
# there might be some outer-scoped inputs for an inner Graph.
994999
pass
@@ -1294,6 +1299,8 @@ def replace_all_inputs(self, ops, old_input, new_input):
12941299
to_ops = set()
12951300

12961301
for node in ops:
1302+
if node is None:
1303+
continue
12971304
if old_input in node.input and new_input in node.output:
12981305
raise RuntimeError("creating a circle in the graph is not allowed: " + node.name)
12991306
self._input_to_node_name[new_input].add(node.name)

tf2onnx/graph_matcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self, op_type, name=None, inputs=None):
5050
input_pattern if isinstance(input_pattern, OpTypePattern) else
5151
OpTypePattern(input_pattern) for input_pattern in inputs
5252
]
53+
self.op_type_set = set(op_type.split('|')) if op_type else set()
5354

5455
@property
5556
def op_type(self):
@@ -154,7 +155,7 @@ def _is_op_type_same(op, pattern):
154155
if pattern.op_type == "*":
155156
return True
156157

157-
if op.type in pattern.op_type.split('|'):
158+
if op.type in pattern.op_type_set:
158159
return True
159160

160161
return False

tf2onnx/onnx_opset/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def version_8(cls, ctx, node, **kwargs):
183183

184184
@classmethod
185185
def version_12(cls, ctx, node, **kwargs):
186-
node.name = 'Clip' # clip supports all types now
186+
node.type = 'Clip' # clip supports all types now
187187

188188
@tf_op("Softmax")
189189
class Softmax:

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,13 @@ def version_1(cls, ctx, node, **kwargs):
272272
raise RuntimeError("all inputs of {} are empty".format(node.name))
273273

274274
axis_node = node.inputs[-1]
275-
utils.make_sure(axis_node.is_const(), "{} needs to be const".format(axis_node.name))
275+
utils.make_sure(axis_node.is_const(), "%r needs to be const", axis_node.name)
276276
axis_val = axis_node.get_tensor_value()
277277
ctx.remove_input(node, node.input[-1])
278278

279279
if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports.
280280
input_shape = ctx.get_shape(node.input[0])
281-
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
281+
utils.make_sure(input_shape is not None, "shape of %r is None", node.input[0])
282282
axis_val = len(input_shape) + axis_val
283283
node.set_attr("axis", axis_val)
284284

tf2onnx/tf_utils.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from distutils.version import LooseVersion
1414

1515
import numpy as np
16-
import six
1716
import tensorflow as tf
1817

1918
from tensorflow.core.framework import types_pb2, tensor_pb2
@@ -70,7 +69,7 @@ def get_tf_tensor_data(tensor):
7069
"""Get data from tensor."""
7170
make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
7271
np_data = tensor_util.MakeNdarray(tensor)
73-
make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
72+
make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)
7473
return np_data
7574

7675

@@ -83,7 +82,7 @@ def get_tf_const_value(op, as_list=True):
8382
when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
8483
when as_list=True, return 1, type is <class 'int'>.
8584
"""
86-
make_sure(is_tf_const_op(op), "{} isn't a const op".format(op.name))
85+
make_sure(is_tf_const_op(op), "%r isn't a const op", op.name)
8786
value = get_tf_tensor_data(op.get_attr("value"))
8887
if as_list:
8988
value = value.tolist()
@@ -119,9 +118,6 @@ def map_tf_dtype(dtype):
119118

120119
def get_tf_node_attr(node, name):
121120
"""Parser TF node attribute."""
122-
if six.PY2:
123-
# For python2, TF get_attr does not accept unicode
124-
name = str(name)
125121
return node.get_attr(name)
126122

127123

@@ -136,14 +132,14 @@ def tflist_to_onnx(g, shape_override):
136132
"""
137133

138134
# ignore the following attributes
139-
ignored_attr = ["unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
135+
ignored_attr = {"unknown_rank", "_class", "Tshape", "use_cudnn_on_gpu", "Index", "Tpaddings",
140136
"TI", "Tparams", "Tindices", "Tlen", "Tdim", "Tin", "dynamic_size", "Tmultiples",
141137
"Tblock_shape", "Tcrops", "index_type", "Taxis", "U", "maxval",
142138
"Tout", "Tlabels", "Tindex", "element_shape", "Targmax", "Tperm", "Tcond",
143139
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
144140
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
145141
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
146-
"Toutput_types"]
142+
"Toutput_types"}
147143

148144
node_list = g.get_operations()
149145
functions = {}
@@ -176,12 +172,11 @@ def tflist_to_onnx(g, shape_override):
176172
attr_cnt[a] += 1
177173
if a == "dtype":
178174
attr[a] = map_tf_dtype(get_tf_node_attr(node, "dtype"))
179-
elif a in ["T"]:
175+
elif a == "T":
180176
dtype = get_tf_node_attr(node, a)
181-
if dtype:
182-
if not isinstance(dtype, list):
183-
dtypes[node.name] = map_tf_dtype(dtype)
184-
elif a in ["output_type", "output_dtype", "out_type", "Tidx", "out_idx"]:
177+
if dtype and not isinstance(dtype, list):
178+
dtypes[node.name] = map_tf_dtype(dtype)
179+
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}:
185180
# Tidx is used by Range
186181
# out_idx is used by ListDiff
187182
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
@@ -192,7 +187,7 @@ def tflist_to_onnx(g, shape_override):
192187
elif a == "output_shapes":
193188
# we should not need it since we pull the shapes above already
194189
pass
195-
elif a in ["body", "cond", "then_branch", "else_branch"]:
190+
elif a in {"body", "cond", "then_branch", "else_branch"}:
196191
input_shapes = [inp.get_shape() for inp in node.inputs]
197192
nattr = get_tf_node_attr(node, a)
198193
attr[a] = nattr.name

0 commit comments

Comments
 (0)