@@ -47,13 +47,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
47
47
try :
48
48
# pylint: disable=protected-access
49
49
from tensorflow .python .saved_model .load import _RestoredResource as TfRestoredResourceType
50
+ from tensorflow .python .ops .lookup_ops import StaticHashTable as TfStaticHashTableType
51
+ from tensorflow .python .training .tracking .base import Trackable as TfTrackableType
50
52
except ImportError :
51
53
TfRestoredResourceType = tuple () # isinstance(x, tuple()) is always false
52
-
53
- try :
54
- from tensorflow .python .training .tracking .tracking import AutoTrackable as TfAutoTrackableType
55
- except ImportError :
56
- TfAutoTrackableType = tuple ()
54
+ TfStaticHashTableType = tuple ()
55
+ TfTrackableType = tuple ()
57
56
58
57
if is_tf2 ():
59
58
convert_variables_to_constants = tf .compat .v1 .graph_util .convert_variables_to_constants
@@ -147,6 +146,46 @@ def fix_freezing_errors(graph_def):
147
146
return graph_def
148
147
149
148
149
+ def from_trackable (trackable , concrete_func , inputs , outputs , large_model ):
150
+ err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
151
+
152
+ # Avoid errors due to bug in TF freezing
153
+ removed_resource_to_placeholder , graph_captures_copy , func_captures_copy = \
154
+ _remove_non_variable_resources_from_captures (concrete_func )
155
+
156
+ try :
157
+ frozen_graph = from_function (concrete_func , inputs , outputs , large_model )
158
+ except ValueError as e :
159
+ if any (msg in str (e ) for msg in ["exceeds maximum protobuf size of 2GB" , "string too long" ]):
160
+ raise ValueError (err_large_model )
161
+ raise e
162
+
163
+ # We might be returning the concrete_func so let's put it back in working order
164
+ _restore_captured_resources (concrete_func , graph_captures_copy , func_captures_copy )
165
+
166
+ table_names , key_dtypes , value_dtypes = get_hash_table_info (frozen_graph )
167
+ placeholder_to_table_info = {}
168
+ _get_hash_table_info_from_trackable (trackable , table_names , key_dtypes , value_dtypes ,
169
+ removed_resource_to_placeholder , placeholder_to_table_info )
170
+
171
+ initialized_tables = {}
172
+ for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
173
+ h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
174
+ try :
175
+ k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
176
+ initialized_tables [n ] = (k .numpy (), v .numpy ())
177
+ except Exception : # pylint: disable=broad-except
178
+ logger .warning ("Could not initialize table with shared_name = %r" , n )
179
+
180
+ for placeholder in removed_resource_to_placeholder .values ():
181
+ if placeholder not in placeholder_to_table_info :
182
+ logger .error ("Could not find table resource to replace placeholder %s" , placeholder )
183
+
184
+ replace_placeholders_with_tables (frozen_graph , placeholder_to_table_info )
185
+
186
+ return frozen_graph , initialized_tables
187
+
188
+
150
189
def from_function (func , input_names , output_names , large_model = False ):
151
190
if large_model :
152
191
return convert_variables_to_constants_large_model (func )
@@ -327,7 +366,27 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
327
366
def _get_hash_table_info_from_trackable (trackable , table_names , key_dtypes , value_dtypes ,
328
367
removed_resource_to_placeholder , placeholder_to_table_info ):
329
368
# pylint: disable=protected-access
330
- for r in trackable .__dict__ .values ():
369
+ stack = [trackable ]
370
+ visited = set ()
371
+ while stack :
372
+ r = stack .pop ()
373
+ visited .add (id (r ))
374
+ try :
375
+ for trackable_ref in r ._checkpoint_dependencies :
376
+ if id (trackable_ref .ref ) not in visited :
377
+ if isinstance (trackable_ref .ref , TfTrackableType ):
378
+ stack .append (trackable_ref .ref )
379
+ except Exception : # pylint: disable=broad-except
380
+ continue
381
+ for t in r .__dict__ .values () if hasattr (r , '__dict__' ) else []:
382
+ if isinstance (t , TfStaticHashTableType ) and hasattr (t , '_shared_name' ):
383
+ table_names .append (t ._shared_name .encode ())
384
+ key_dtypes .append (t .key_dtype .as_datatype_enum )
385
+ value_dtypes .append (t .value_dtype .as_datatype_enum )
386
+ table_handle = id (t .resource_handle )
387
+ if table_handle in removed_resource_to_placeholder :
388
+ table_info = (table_names [- 1 ], key_dtypes [- 1 ], value_dtypes [- 1 ])
389
+ placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
331
390
if isinstance (r , TfRestoredResourceType ) and hasattr (r , '_create_resource' ):
332
391
try :
333
392
table_handle = id (r .resource_handle )
@@ -341,9 +400,6 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
341
400
if table_handle in removed_resource_to_placeholder and len (new_names ) == 1 :
342
401
table_info = (new_names [0 ], new_k_dtypes [0 ], new_v_dtypes [0 ])
343
402
placeholder_to_table_info [removed_resource_to_placeholder [table_handle ]] = table_info
344
- if isinstance (r , TfAutoTrackableType ):
345
- _get_hash_table_info_from_trackable (r , table_names , key_dtypes , value_dtypes ,
346
- removed_resource_to_placeholder , placeholder_to_table_info )
347
403
348
404
349
405
def _remove_non_variable_resources_from_captures (concrete_func ):
@@ -399,7 +455,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
399
455
err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
400
456
err_no_sig = "No signatures found in model. Try --concrete_function instead."
401
457
err_sig_nomatch = "Specified signature not in model %s"
402
- err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
403
458
404
459
if tag is None :
405
460
tag = ['serve' ]
@@ -456,39 +511,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
456
511
outputs = output_names
457
512
logger .info ("Outputs not left as None; will use provided names not structured output names." )
458
513
459
- # Avoid errors due to bug in TF freezing
460
- removed_resource_to_placeholder , graph_captures_copy , func_captures_copy = \
461
- _remove_non_variable_resources_from_captures (concrete_func )
462
-
463
- try :
464
- frozen_graph = from_function (concrete_func , inputs , outputs , large_model )
465
- except ValueError as e :
466
- if any (msg in str (e ) for msg in ["exceeds maximum protobuf size of 2GB" , "string too long" ]):
467
- raise ValueError (err_large_model )
468
- raise e
469
-
470
- # We might be returning the concrete_func so let's put it back in working order
471
- _restore_captured_resources (concrete_func , graph_captures_copy , func_captures_copy )
472
-
473
- table_names , key_dtypes , value_dtypes = get_hash_table_info (frozen_graph )
474
- placeholder_to_table_info = {}
475
- _get_hash_table_info_from_trackable (imported , table_names , key_dtypes , value_dtypes ,
476
- removed_resource_to_placeholder , placeholder_to_table_info )
477
-
478
- initialized_tables = {}
479
- for n , k_dtype , val_dtype in zip (table_names , key_dtypes , value_dtypes ):
480
- h = lookup_ops .hash_table_v2 (k_dtype , val_dtype , shared_name = n )
481
- try :
482
- k , v = lookup_ops .lookup_table_export_v2 (h , k_dtype , val_dtype )
483
- initialized_tables [n ] = (k .numpy (), v .numpy ())
484
- except Exception : # pylint: disable=broad-except
485
- logger .warning ("Could not initialize table with shared_name = %r" , n )
486
-
487
- for placeholder in removed_resource_to_placeholder .values ():
488
- if placeholder not in placeholder_to_table_info :
489
- logger .error ("Could not find table resource to replace placeholder %s" , placeholder )
490
-
491
- replace_placeholders_with_tables (frozen_graph , placeholder_to_table_info )
514
+ frozen_graph , initialized_tables = from_trackable (imported , concrete_func , inputs , outputs , large_model )
492
515
493
516
return frozen_graph , inputs , outputs , concrete_func , imported , initialized_tables , tensors_to_rename
494
517
0 commit comments