Skip to content

Commit 2eafd21

Browse files
Merge pull request #1228 from onnx/tom/RecursiveTableInfoSearch
Added recursive search for table info of saved models
2 parents ba4ce36 + 0dbbd25 commit 2eafd21

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

tf2onnx/tf_loader.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
4949
# pylint: disable=protected-access
5050
from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
5151
except ImportError:
52-
TfRestoredResourceType = None
52+
TfRestoredResourceType = tuple() # isinstance(x, tuple()) is always false
53+
54+
try:
55+
from tensorflow.python.training.tracking.tracking import AutoTrackable as TfAutoTrackableType
56+
except ImportError:
57+
TfAutoTrackableType = tuple()
5358

5459
if is_tf2():
5560
convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
@@ -266,6 +271,25 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
266271
return frozen_graph, input_names, output_names
267272

268273

274+
def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
275+
removed_resource_to_placeholder, placeholder_to_table_info):
276+
# pylint: disable=protected-access
277+
for r in trackable.__dict__.values():
278+
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource') and hasattr(r, 'resource_handle'):
279+
initializer = r._create_resource.concrete_functions[0].function_def
280+
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
281+
table_names.extend(new_names)
282+
key_dtypes.extend(new_k_dtypes)
283+
value_dtypes.extend(new_v_dtypes)
284+
table_handle = id(r.resource_handle)
285+
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
286+
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
287+
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
288+
if isinstance(r, TfAutoTrackableType):
289+
_get_hash_table_info_from_trackable(r, table_names, key_dtypes, value_dtypes,
290+
removed_resource_to_placeholder, placeholder_to_table_info)
291+
292+
269293
def _remove_non_variable_resources_from_captures(concrete_func):
270294
"""
271295
Removes all non-variable resources (such as tables) from a function's captured inputs to prevent tf from
@@ -370,19 +394,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
370394

371395
table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
372396
placeholder_to_table_info = {}
373-
for r in imported.__dict__.values():
374-
if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource') and hasattr(r, 'resource_handle'):
375-
# Add tables from saved_model table initializers
376-
# pylint: disable=protected-access
377-
initializer = r._create_resource.concrete_functions[0].function_def
378-
new_names, new_k_dtypes, new_v_dtypes = get_hash_table_info(initializer.node_def)
379-
table_names.extend(new_names)
380-
key_dtypes.extend(new_k_dtypes)
381-
value_dtypes.extend(new_v_dtypes)
382-
table_handle = id(r.resource_handle)
383-
if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
384-
table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
385-
placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
397+
_get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
398+
removed_resource_to_placeholder, placeholder_to_table_info)
386399

387400
initialized_tables = {}
388401
for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
@@ -393,6 +406,10 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
393406
except Exception: # pylint: disable=broad-except
394407
logger.warning("Could not initialize table with shared_name = %r", n)
395408

409+
for placeholder in removed_resource_to_placeholder.values():
410+
if placeholder not in placeholder_to_table_info:
411+
logger.error("Could not find table resource to replace placeholder %s", placeholder)
412+
396413
replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
397414

398415
return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables

0 commit comments

Comments
 (0)