13
13
using Tensorflow . Variables ;
14
14
using Tensorflow . Functions ;
15
15
using Tensorflow . Training . Saving . SavedModel ;
16
+ using Tensorflow . Trackables ;
16
17
17
18
namespace Tensorflow
18
19
{
@@ -51,9 +52,13 @@ public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto,
51
52
_node_filters = filters ;
52
53
_node_path_to_id = _convert_node_paths_to_ints ( ) ;
53
54
_loaded_nodes = new Dictionary < int , ( Trackable , Action < object , object , object > ) > ( ) ;
54
- foreach ( var filter in filters )
55
+
56
+ if ( filters != null )
55
57
{
56
- _loaded_nodes [ _node_path_to_id [ filter . Key ] ] = filter . Value ;
58
+ foreach ( var filter in filters )
59
+ {
60
+ _loaded_nodes [ _node_path_to_id [ filter . Key ] ] = filter . Value ;
61
+ }
57
62
}
58
63
59
64
_filtered_nodes = _retrieve_all_filtered_nodes ( ) ;
@@ -535,7 +540,13 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
535
540
dependencies [ item . Key ] = nodes [ item . Value ] ;
536
541
}
537
542
538
- return _recreate_default ( proto , node_id , dependencies ) ;
543
+ return proto . KindCase switch
544
+ {
545
+ SavedObject . KindOneofCase . Resource => RestoredResource . deserialize_from_proto ( ) ,
546
+ SavedObject . KindOneofCase . Asset => Asset . deserialize_from_proto ( ) ,
547
+ SavedObject . KindOneofCase . Constant => TrackableConstant . deserialize_from_proto ( ) ,
548
+ _ => _recreate_default ( proto , node_id , dependencies )
549
+ } ;
539
550
}
540
551
541
552
/// <summary>
@@ -549,7 +560,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
549
560
return proto . KindCase switch
550
561
{
551
562
SavedObject . KindOneofCase . UserObject => _recreate_user_object ( proto . UserObject , node_id ) ,
552
- SavedObject . KindOneofCase . Function => throw new NotImplementedException ( ) ,
563
+ SavedObject . KindOneofCase . Function => _recreate_function ( proto . Function , null ) ,
553
564
SavedObject . KindOneofCase . BareConcreteFunction => throw new NotImplementedException ( ) ,
554
565
SavedObject . KindOneofCase . Variable => _recreate_variable ( proto . Variable ) ,
555
566
SavedObject . KindOneofCase . CapturedTensor => throw new NotImplementedException ( )
@@ -609,6 +620,13 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
609
620
}
610
621
}
611
622
623
+ private ( ConcreteFunction , Action < object , object , object > ) _recreate_function ( SavedFunction proto ,
624
+ Dictionary < Maybe < string , int > , Trackable > dependencies )
625
+ {
626
+ throw new NotImplementedException ( ) ;
627
+ //var fn = function_deserialization.setup_bare_concrete_function(proto, )
628
+ }
629
+
612
630
private ( ConcreteFunction , Action < object , object , object > ) _recreate_bare_concrete_function ( SavedBareConcreteFunction proto ,
613
631
Dictionary < Maybe < string , int > , Trackable > dependencies )
614
632
{
0 commit comments