2020
2121import json
2222import os
23+ import shutil
24+ import tempfile
25+ from zipfile import ZipFile
2326
27+ # Required to load saved models that use TFDF.
28+ import tensorflow_decision_forests
2429import tensorflow as tf
2530from tensorflow .core .framework import graph_pb2
2631from tensorflow .core .framework import node_def_pb2
2732from tensorflow .core .protobuf import config_pb2
2833from tensorflow .core .protobuf import device_properties_pb2
2934from tensorflow .core .protobuf import meta_graph_pb2
35+ from tensorflow .io import gfile
3036from tensorflow .python .checkpoint .trackable_view import TrackableView
3137from tensorflow .python .eager import context
3238from tensorflow .python .framework import convert_to_constants
@@ -399,7 +405,7 @@ def write_artifacts(topology,
399405 assert isinstance (weights_manifest , list )
400406 model_json [common .ARTIFACT_WEIGHTS_MANIFEST_KEY ] = weights_manifest
401407
402- with tf . io . gfile .GFile (output_graph , 'w' ) as f :
408+ with gfile .GFile (output_graph , 'w' ) as f :
403409 json .dump (model_json , f )
404410
405411def _remove_unused_control_flow_inputs (input_graph_def ):
@@ -421,6 +427,49 @@ def _check_signature_in_model(saved_model, signature_name):
421427 "are available: %s" % (signature_name ,
422428 saved_model .signatures .keys ()))
423429
430+ def _copy_assets (saved_model_dir , output_dir ):
431+ input_assets_path = os .path .join (saved_model_dir , common .ASSETS_DIRECTORY_NAME )
432+
433+ if gfile .exists (input_assets_path ) and gfile .isdir (input_assets_path ):
434+
435+ tmp_dir = tempfile .mkdtemp ()
436+ zip_path = gfile .join (tmp_dir , common .ASSETS_DIRECTORY_NAME + '.zip' )
437+
438+ with ZipFile (zip_path , 'w' ) as archive :
439+ for (input_dir_path , _ , file_names ) in gfile .walk (input_assets_path ):
440+
441+ relative_dir_path = os .path .relpath (input_dir_path , input_assets_path )
442+
443+ for file_name in file_names :
444+
445+ input_file_path = gfile .join (input_dir_path , file_name )
446+ relative_file_path = gfile .join (relative_dir_path , file_name )
447+
448+ with gfile .GFile (input_file_path , 'rb' ) as input_file :
449+ with archive .open (relative_file_path , 'w' ) as relative_file :
450+ shutil .copyfileobj (input_file , relative_file )
451+
452+ output_assets_path = gfile .join (output_dir , common .ASSETS_DIRECTORY_NAME + '.zip' )
453+ gfile .copy (zip_path , output_assets_path , overwrite = True )
454+
455+ if gfile .isdir (tmp_dir ):
456+ gfile .rmtree (tmp_dir )
457+
458+ # TFDF stores the necessary files for its binary in the assets folder.
459+ ASSET_REQUIRING_OPS = set ([
460+ 'SimpleMLCreateModelResource'
461+ 'SimpleMLLoadModelFromPathWithHandle' ,
462+ 'SimpleMLInferenceOpWithHandle' ,
463+ ])
464+
465+ def _is_assets_required (model_ops ):
466+ return not ASSET_REQUIRING_OPS .isdisjoint (model_ops )
467+
468+ def _get_frozen_graph_ops (frozen_graph ):
469+ if frozen_graph is None :
470+ return []
471+ return [node .op for node in frozen_graph .as_graph_def ().node ]
472+
424473
425474def _freeze_saved_model_v1 (saved_model_dir , saved_model_tags ,
426475 output_node_names ):
@@ -745,8 +794,8 @@ def _convert_tf_saved_model(output_dir,
745794 if signature_def is None :
746795 signature_def = 'serving_default'
747796
748- if not tf . io . gfile .exists (output_dir ):
749- tf . io . gfile .makedirs (output_dir )
797+ if not gfile .exists (output_dir ):
798+ gfile .makedirs (output_dir )
750799 output_graph = os .path .join (
751800 output_dir , common .ARTIFACT_MODEL_JSON_FILE_NAME )
752801
@@ -852,6 +901,12 @@ def _convert_tf_saved_model(output_dir,
852901 # tensorflow version.
853902 tf_version = tf .__version__
854903
904+ if saved_model_dir :
905+ model_ops = set (_get_frozen_graph_ops (frozen_graph )) | \
906+ set (_get_frozen_graph_ops (frozen_initializer_graph ))
907+ if _is_assets_required (model_ops ):
908+ _copy_assets (saved_model_dir , output_dir )
909+
855910 optimize_graph (frozen_graph , signature ,
856911 output_graph , tf_version ,
857912 quantization_dtype_map = quantization_dtype_map ,
@@ -1137,7 +1192,7 @@ def convert_tf_hub_module(module_handle, output_dir,
11371192 # TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
11381193 # modules is fixed on the TF side, or once the modules we cannot load become
11391194 # replaced with newer versions.
1140- if tf . io . gfile .exists (os .path .join (module_path , _HUB_V1_MODULE_PB )):
1195+ if gfile .exists (os .path .join (module_path , _HUB_V1_MODULE_PB )):
11411196 print ("Loading the module using TF 1.X interface from %s." % module_path )
11421197 convert_tf_hub_module_v1 (module_path , output_dir , signature ,
11431198 quantization_dtype_map ,
0 commit comments