Skip to content

Commit d24dd82

Browse files
Pylint
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8770b52 commit d24dd82

File tree

4 files changed

+28
-24
lines changed

4 files changed

+28
-24
lines changed

tests/test_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@ def func(m):
18491849
feed_dict = {_INPUT: m_val}
18501850
if "input" in g.input_names:
18511851
# TFLite inputs don't have port numbers
1852-
feed_dict = { k.split(":")[0]: v for k, v in feed_dict.items() }
1852+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
18531853
results = self.run_backend(g, g.outputs, feed_dict)
18541854
numbers = set(results[0].flatten())
18551855
self.assertEqual(sorted(numbers), list(range(8)))
@@ -1867,7 +1867,7 @@ def func(n, m):
18671867
feed_dict = {_INPUT: n_val, _INPUT1: m_val}
18681868
if "input" in g.input_names:
18691869
# TFLite inputs don't have port numbers
1870-
feed_dict = { k.split(":")[0]: v for k, v in feed_dict.items() }
1870+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
18711871
results = self.run_backend(g, g.outputs, feed_dict)
18721872
numbers = set(results[0].flatten())
18731873
self.assertEqual(sorted(numbers), list(range(2, 10)))
@@ -1887,7 +1887,7 @@ def func(n, m, s):
18871887
feed_dict = {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val}
18881888
if "input" in g.input_names:
18891889
# TFLite inputs don't have port numbers
1890-
feed_dict = { k.split(":")[0]: v for k, v in feed_dict.items() }
1890+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
18911891
results = self.run_backend(g, g.outputs, feed_dict)
18921892
numbers = set(results[0].flatten())
18931893
self.assertEqual(sorted(numbers), list(range(2, 10)))

tf2onnx/tf_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,12 @@ def replace_placeholders_with_tables(graph_def, placeholder_to_table_info):
319319
n.attr['value_dtype'].type = val_dtype
320320

321321
def read_tf_node_def_attrs(node_def, input_dtypes, input_shapes):
322-
from tf2onnx.tf_loader import tf_session, tf_placeholder
322+
"""Given a tf node def, returns a dict of attribute names to values"""
323+
from tf2onnx.tf_loader import tf_session, tf_placeholder # pylint: disable=import-outside-toplevel
323324
del node_def.input[:]
324325
node_def.name = "node"
325326

327+
# read_tf_node_attrs uses some tf methods that require the node to be loaded into a valid TF graph
326328
g = tf.Graph()
327329
with g.as_default():
328330
for i, (dtype, shape) in enumerate(zip(input_dtypes, input_shapes)):
@@ -340,6 +342,7 @@ def read_tf_node_def_attrs(node_def, input_dtypes, input_shapes):
340342

341343

342344
def read_tf_node_attrs(node):
345+
"""Given a tf Node, returns a dict of attribute names to values"""
343346
attr = {}
344347
attr_cnt = collections.Counter()
345348

tf2onnx/tflite_utils.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
1919
from tf2onnx.tflite.Model import Model
2020
from tf2onnx.flexbuffers import read_flexbuffer
21-
from tf2onnx.tf_utils import read_tf_node_attrs, read_tf_node_def_attrs
21+
from tf2onnx.tf_utils import read_tf_node_def_attrs
2222
from tf2onnx import utils
2323

2424
logger = logging.getLogger(__name__)
@@ -126,6 +126,7 @@ def get_options_class(name):
126126
def read_tflite_model(tflite_path):
127127
"""
128128
Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model)
129+
Graphs are topologically sorted and the main graph is last
129130
Pass these to parse_tflite_graph
130131
"""
131132
with open(tflite_path, 'rb') as f:
@@ -160,14 +161,16 @@ def read_tflite_model(tflite_path):
160161
tensor_shapes[name] = details["shape"].tolist()
161162
except Exception as e: # pylint: disable=broad-except
162163
logger.warning("Error loading model into tflite interpreter: %s", e)
164+
tflite_graphs = get_model_subgraphs(model)
163165
return tflite_graphs, opcodes_map, model, tensor_shapes
164166

165167

166-
def get_subgraph_dependencies(tflite_g, opcodes_map, model):
167-
"""Returns a list of subgraph names referenced by the provided graph"""
168+
def get_subgraph_dependencies(model, graph_idx):
169+
"""Returns a list of subgraph indices referenced by the indicated graph"""
168170
dependencies = []
169-
for i in range(tflite_g.OperatorsLength()):
170-
op = tflite_g.Operators(i)
171+
g = model.Subgraphs(graph_idx)
172+
for i in range(g.OperatorsLength()):
173+
op = g.Operators(i)
171174
options_type_name = lookup_enum(op.BuiltinOptionsType(), 'BuiltinOptions')
172175
option_class = get_options_class(options_type_name)
173176
if option_class is not None:
@@ -176,21 +179,20 @@ def get_subgraph_dependencies(tflite_g, opcodes_map, model):
176179
for attr in FUNCTION_ATTRS:
177180
if hasattr(options, attr):
178181
value = getattr(options, attr)()
179-
dependencies.append(model.Subgraphs(value).Name().decode())
182+
dependencies.append(value)
180183
return dependencies
181184

182185

183-
def topsort_tfl_subgraphs(tflite_graphs, opcodes_map, model):
186+
def get_model_subgraphs(model):
184187
"""Returns topologically sorted subgraphs of a model. Guarantees main graph is placed at the end."""
185-
main_g = tflite_graphs[0].Name().decode()
188+
main_g = 0
186189
dependencies = {}
187-
name_to_graph = {}
188-
for g in tflite_graphs:
189-
name = g.Name().decode()
190-
name_to_graph[name] = g
191-
ds = get_subgraph_dependencies(g, opcodes_map, model)
192-
utils.make_sure(main_g not in ds, "Main graph %s is a dependency of subgraph %s", main_g, name)
193-
dependencies[name] = ds
190+
idx_to_graph = {}
191+
for i in range(model.SubgraphsLength()):
192+
idx_to_graph[i] = model.Subgraphs(i)
193+
ds = get_subgraph_dependencies(model, i)
194+
utils.make_sure(main_g not in ds, "Main graph %s is a dependency of subgraph %s", main_g, i)
195+
dependencies[i] = ds
194196

195197
ordered = []
196198
visited = set()
@@ -206,10 +208,10 @@ def visit(g):
206208
ordered.append(g)
207209
visiting.remove(g)
208210

209-
for g in reversed(tflite_graphs):
210-
visit(g.Name().decode())
211+
for g in reversed(range(model.SubgraphsLength())):
212+
visit(g)
211213

212-
return [name_to_graph[n] for n in ordered]
214+
return [idx_to_graph[i] for i in ordered]
213215

214216

215217
def get_quantization_attr(quant_params):

tf2onnx/tfonnx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from tf2onnx.shape_inference import infer_shape
2727
from tf2onnx.tf_loader import is_function, resolve_functions, set_function
2828
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
29-
from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph, topsort_tfl_subgraphs
29+
from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph
3030

3131
from . import constants, logging, schemas, utils, handler
3232

@@ -462,7 +462,6 @@ def rename_tensors_in_nodes(onnx_nodes):
462462

463463
if tflite_path is not None:
464464
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
465-
tflite_graphs = topsort_tfl_subgraphs(tflite_graphs, opcodes, model)
466465
main_g = None
467466
inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
468467
for i, tfl_graph in enumerate(tflite_graphs):

0 commit comments

Comments
 (0)