18
18
from tf2onnx .tflite .TensorType import TensorType as TFLiteTensorType
19
19
from tf2onnx .tflite .Model import Model
20
20
from tf2onnx .flexbuffers import read_flexbuffer
21
- from tf2onnx .tf_utils import read_tf_node_attrs , read_tf_node_def_attrs
21
+ from tf2onnx .tf_utils import read_tf_node_def_attrs
22
22
from tf2onnx import utils
23
23
24
24
logger = logging .getLogger (__name__ )
@@ -126,6 +126,7 @@ def get_options_class(name):
126
126
def read_tflite_model (tflite_path ):
127
127
"""
128
128
Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model)
129
+ Graphs are topologically sorted and the main graph is last
129
130
Pass these to parse_tflite_graph
130
131
"""
131
132
with open (tflite_path , 'rb' ) as f :
@@ -160,14 +161,16 @@ def read_tflite_model(tflite_path):
160
161
tensor_shapes [name ] = details ["shape" ].tolist ()
161
162
except Exception as e : # pylint: disable=broad-except
162
163
logger .warning ("Error loading model into tflite interpreter: %s" , e )
164
+ tflite_graphs = get_model_subgraphs (model )
163
165
return tflite_graphs , opcodes_map , model , tensor_shapes
164
166
165
167
166
- def get_subgraph_dependencies (tflite_g , opcodes_map , model ):
167
- """Returns a list of subgraph names referenced by the provided graph"""
168
+ def get_subgraph_dependencies (model , graph_idx ):
169
+ """Returns a list of subgraph indices referenced by the indicated graph"""
168
170
dependencies = []
169
- for i in range (tflite_g .OperatorsLength ()):
170
- op = tflite_g .Operators (i )
171
+ g = model .Subgraphs (graph_idx )
172
+ for i in range (g .OperatorsLength ()):
173
+ op = g .Operators (i )
171
174
options_type_name = lookup_enum (op .BuiltinOptionsType (), 'BuiltinOptions' )
172
175
option_class = get_options_class (options_type_name )
173
176
if option_class is not None :
@@ -176,21 +179,20 @@ def get_subgraph_dependencies(tflite_g, opcodes_map, model):
176
179
for attr in FUNCTION_ATTRS :
177
180
if hasattr (options , attr ):
178
181
value = getattr (options , attr )()
179
- dependencies .append (model . Subgraphs ( value ). Name (). decode () )
182
+ dependencies .append (value )
180
183
return dependencies
181
184
182
185
183
- def topsort_tfl_subgraphs ( tflite_graphs , opcodes_map , model ):
186
+ def get_model_subgraphs ( model ):
184
187
"""Returns topologically sorted subgraphs of a model. Guarantees main graph is placed at the end."""
185
- main_g = tflite_graphs [ 0 ]. Name (). decode ()
188
+ main_g = 0
186
189
dependencies = {}
187
- name_to_graph = {}
188
- for g in tflite_graphs :
189
- name = g .Name ().decode ()
190
- name_to_graph [name ] = g
191
- ds = get_subgraph_dependencies (g , opcodes_map , model )
192
- utils .make_sure (main_g not in ds , "Main graph %s is a dependency of subgraph %s" , main_g , name )
193
- dependencies [name ] = ds
190
+ idx_to_graph = {}
191
+ for i in range (model .SubgraphsLength ()):
192
+ idx_to_graph [i ] = model .Subgraphs (i )
193
+ ds = get_subgraph_dependencies (model , i )
194
+ utils .make_sure (main_g not in ds , "Main graph %s is a dependency of subgraph %s" , main_g , i )
195
+ dependencies [i ] = ds
194
196
195
197
ordered = []
196
198
visited = set ()
@@ -206,10 +208,10 @@ def visit(g):
206
208
ordered .append (g )
207
209
visiting .remove (g )
208
210
209
- for g in reversed (tflite_graphs ):
210
- visit (g . Name (). decode () )
211
+ for g in reversed (range ( model . SubgraphsLength ()) ):
212
+ visit (g )
211
213
212
- return [name_to_graph [ n ] for n in ordered ]
214
+ return [idx_to_graph [ i ] for i in ordered ]
213
215
214
216
215
217
def get_quantization_attr (quant_params ):
0 commit comments