@@ -49,7 +49,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
49
49
# pylint: disable=protected-access
50
50
from tensorflow .python .saved_model .load import _RestoredResource as TfRestoredResourceType
51
51
except ImportError :
52
- TfRestoredResourceType = None
52
+ TfRestoredResourceType = tuple () # isinstance(x, tuple()) is always false
53
+
54
+ try :
55
+ from tensorflow .python .training .tracking .tracking import AutoTrackable as TfAutoTrackableType
56
+ except ImportError :
57
+ TfAutoTrackableType = tuple ()
53
58
54
59
if is_tf2 ():
55
60
convert_variables_to_constants = tf .compat .v1 .graph_util .convert_variables_to_constants
@@ -266,6 +271,25 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
266
271
return frozen_graph , input_names , output_names
267
272
268
273
274
+ def _get_hash_table_info_from_trackable (trackable , table_names , key_dtypes , value_dtypes ,
275
+ removed_resource_to_placeholder , placeholder_to_table_info ):
276
+ # pylint: disable=protected-access
277
+ for r in trackable .__dict__ .values ():
278
+ if isinstance (r , TfRestoredResourceType ) and hasattr (r , '_create_resource' ) and hasattr (r , 'resource_handle' ):
279
+ initializer = r ._create_resource .concrete_functions [0 ].function_def
280
+ new_names , new_k_dtypes , new_v_dtypes = get_hash_table_info (initializer .node_def )
281
+ table_names .extend (new_names )
282
+ key_dtypes .extend (new_k_dtypes )
283
+ value_dtypes .extend (new_v_dtypes )
284
+ table_handle = id (r .resource_handle )
285
+ if table_handle in removed_resource_to_placeholder and len (new_names ) == 1 :
286
+ table_info = (new_names [0 ], new_k_dtypes [0 ], new_v_dtypes [0 ])
287
+ placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
288
+ if isinstance (r , TfAutoTrackableType ):
289
+ _get_hash_table_info_from_trackable (r , table_names , key_dtypes , value_dtypes ,
290
+ removed_resource_to_placeholder , placeholder_to_table_info )
291
+
292
+
269
293
def _remove_non_variable_resources_from_captures (concrete_func ):
270
294
"""
271
295
Removes all non-variable resources (such as tables) from a function's captured inputs to prevent tf from
@@ -370,19 +394,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
370
394
371
395
table_names , key_dtypes , value_dtypes = get_hash_table_info (frozen_graph )
372
396
placeholder_to_table_info = {}
373
- for r in imported .__dict__ .values ():
374
- if isinstance (r , TfRestoredResourceType ) and hasattr (r , '_create_resource' ) and hasattr (r , 'resource_handle' ):
375
- # Add tables from saved_model table initializers
376
- # pylint: disable=protected-access
377
- initializer = r ._create_resource .concrete_functions [0 ].function_def
378
- new_names , new_k_dtypes , new_v_dtypes = get_hash_table_info (initializer .node_def )
379
- table_names .extend (new_names )
380
- key_dtypes .extend (new_k_dtypes )
381
- value_dtypes .extend (new_v_dtypes )
382
- table_handle = id (r .resource_handle )
383
- if table_handle in removed_resource_to_placeholder and len (new_names ) == 1 :
384
- table_info = (new_names [0 ], new_k_dtypes [0 ], new_v_dtypes [0 ])
385
- placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
397
+ _get_hash_table_info_from_trackable (imported , table_names , key_dtypes , value_dtypes ,
398
+ removed_resource_to_placeholder , placeholder_to_table_info )
386
399
387
400
initialized_tables = {}
388
401
for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
@@ -393,6 +406,10 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
393
406
except Exception : # pylint: disable=broad-except
394
407
logger .warning ("Could not initialize table with shared_name = %r" , n )
395
408
409
+ for placeholder in removed_resource_to_placeholder .values ():
410
+ if placeholder not in placeholder_to_table_info :
411
+ logger .error ("Could not find table resource to replace placeholder %s" , placeholder )
412
+
396
413
replace_placeholders_with_tables (frozen_graph , placeholder_to_table_info )
397
414
398
415
return frozen_graph , inputs , outputs , concrete_func , imported , initialized_tables
0 commit comments