@@ -211,9 +211,11 @@ public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_pref
211
211
212
212
string restore_device = string . IsNullOrEmpty ( options . experimental_io_device ) ? "cpu:0" : options . experimental_io_device ! ;
213
213
214
- // tf python has code `with ops.device(restore_device):` here.
215
- tf . device ( restore_device ) ; // may be risky.
216
- var restored_tensors = gen_ops . restore_v2 ( file_prefix , tensor_names . ToArray ( ) , slice_specs . ToArray ( ) , tensor_dtypes . ToArray ( ) ) ;
214
+ Tensor [ ] restored_tensors = null ;
215
+ tf_with ( ops . device ( restore_device ) , _ =>
216
+ {
217
+ restored_tensors = gen_ops . restore_v2 ( file_prefix , tensor_names . ToArray ( ) , slice_specs . ToArray ( ) , tensor_dtypes . ToArray ( ) ) ;
218
+ } ) ;
217
219
218
220
Dictionary < string , IDictionary < string , Tensor > > restored_tensor_dict = new ( ) ;
219
221
int idx = 0 ;
@@ -338,11 +340,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
338
340
options = new CheckpointOptions ( ) ;
339
341
}
340
342
341
- tf . device ( "CPU" ) ; // may be risky.
342
- var sharded_suffix = array_ops . where ( gen_ops . regex_full_match ( file_prefix , tf . constant ( @"^s3://.*" ) ) ,
343
+ Tensor tmp_checkpoint_prefix = null ;
344
+ tf_with ( ops . device ( "CPU" ) , _ =>
345
+ {
346
+ var sharded_suffix = array_ops . where ( gen_ops . regex_full_match ( file_prefix , tf . constant ( @"^s3://.*" ) ) ,
343
347
constant_op . constant ( ".part" ) , constant_op . constant ( "_temp/part" ) ) ;
344
- var tmp_checkpoint_prefix = gen_ops . string_join ( new Tensor [ ] { file_prefix , sharded_suffix } ) ;
345
- IDictionary < string , Tensor > registered_paths = _registered_savers . Keys . ToDictionary ( x => x , x => registered_saver_filename ( file_prefix , x ) ) ;
348
+ tmp_checkpoint_prefix = gen_ops . string_join ( new Tensor [ ] { file_prefix , sharded_suffix } ) ;
349
+ IDictionary < string , Tensor > registered_paths = _registered_savers . Keys . ToDictionary ( x => x , x => registered_saver_filename ( file_prefix , x ) ) ;
350
+ } ) ;
346
351
347
352
Operation save_fn ( )
348
353
{
@@ -364,16 +369,24 @@ Operation save_fn()
364
369
var saver = pair . Value ;
365
370
last_device = device ;
366
371
// Skip the extra processing of the device name because no API is available for it.
367
- tf . device ( device ) ;
368
- var shard_prefix = sharded_filename ( tmp_checkpoint_prefix , shard , num_shards_tensor ) ;
372
+ Tensor shard_prefix = null ;
373
+ tf_with ( ops . device ( device ) , _ =>
374
+ {
375
+ shard_prefix = sharded_filename ( tmp_checkpoint_prefix , shard , num_shards_tensor ) ;
376
+ } ) ;
369
377
saved_prefixes . Add ( shard_prefix ) ;
370
- sharded_saves . Add ( saver . save ( shard_prefix , options ) ) ;
378
+ tf_with ( ops . device ( device ) , _ =>
379
+ {
380
+ sharded_saves . Add ( saver . save ( shard_prefix , options ) ) ;
381
+ } ) ;
371
382
}
372
383
using ( var controller = ops . control_dependencies ( sharded_saves . ToArray ( ) ) )
373
384
{
374
385
string merge_device = string . IsNullOrEmpty ( options . experimental_io_device ) ? last_device : options . experimental_io_device ;
375
- tf . device ( merge_device ) ;
376
- return gen_ops . merge_v2_checkpoints ( saved_prefixes . ToArray ( ) , tf . constant ( file_prefix ) , delete_old_dirs : true ) ;
386
+ return tf_with ( ops . device ( merge_device ) , _ =>
387
+ {
388
+ return gen_ops . merge_v2_checkpoints ( saved_prefixes . ToArray ( ) , tf . constant ( file_prefix ) , delete_old_dirs : true ) ;
389
+ } ) ;
377
390
}
378
391
}
379
392
@@ -407,54 +420,56 @@ IDictionary<string, Operation> restore_func()
407
420
{
408
421
var device = single_saver . Key ;
409
422
var saver = single_saver . Value ;
410
- tf . device ( device ) ;
411
- var restored_tensor_dict = saver . restore ( file_prefix , options ) ;
412
-
413
- foreach ( var pair in restored_tensor_dict )
423
+ tf_with ( ops . device ( device ) , _ =>
414
424
{
415
- var checkpoint_key = pair . Key ;
416
- var slice_and_tensor = pair . Value ;
417
- foreach ( var item in slice_and_tensor )
425
+ var restored_tensor_dict = saver . restore ( file_prefix , options ) ;
426
+
427
+ foreach ( var pair in restored_tensor_dict )
418
428
{
419
- var slice_spec = item . Key ;
420
- var tensor = item . Value ;
421
- var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
422
- var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
423
- if ( ! string . IsNullOrEmpty ( slice_spec ) )
429
+ var checkpoint_key = pair . Key ;
430
+ var slice_and_tensor = pair . Value ;
431
+ foreach ( var item in slice_and_tensor )
424
432
{
425
- if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
433
+ var slice_spec = item . Key ;
434
+ var tensor = item . Value ;
435
+ var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
436
+ var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
437
+ if ( ! string . IsNullOrEmpty ( slice_spec ) )
426
438
{
427
- Dictionary < string , Tensor > dict = new ( ) ;
428
- dict [ slice_spec ] = tensor ;
429
- internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( dict ) ;
439
+ if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
440
+ {
441
+ Dictionary < string , Tensor > dict = new ( ) ;
442
+ dict [ slice_spec ] = tensor ;
443
+ internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( dict ) ;
444
+ }
445
+ else
446
+ {
447
+ internal_dict [ checkpoint_key ] . GetValue < IDictionary < string , Tensor > > ( ) [ slice_spec ] = tensor ;
448
+ }
430
449
}
431
450
else
432
451
{
433
- internal_dict [ checkpoint_key ] . GetValue < IDictionary < string , Tensor > > ( ) [ slice_spec ] = tensor ;
452
+ internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( tensor ) ;
434
453
}
435
- }
436
- else
437
- {
438
- internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( tensor ) ;
439
- }
440
- restore_fn_input_count [ restore_fn ] -- ;
454
+ restore_fn_input_count [ restore_fn ] -- ;
441
455
442
- if ( restore_fn_input_count [ restore_fn ] == 0 )
443
- {
444
- Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
445
- foreach ( var input in restore_fn_inputs [ restore_fn ] )
456
+ if ( restore_fn_input_count [ restore_fn ] == 0 )
446
457
{
447
- restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
448
- }
449
- var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
450
- if ( ret is IDictionary < string , Operation > )
451
- {
452
- var dict = ( IDictionary < string , Operation > ) ret ;
453
- restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
458
+ Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
459
+ foreach ( var input in restore_fn_inputs [ restore_fn ] )
460
+ {
461
+ restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
462
+ }
463
+ var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
464
+ if ( ret is IDictionary < string , Operation > )
465
+ {
466
+ var dict = ( IDictionary < string , Operation > ) ret ;
467
+ restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
468
+ }
454
469
}
455
470
}
456
471
}
457
- }
472
+ } ) ;
458
473
}
459
474
460
475
foreach ( var item in _registered_savers )
@@ -500,21 +515,25 @@ public SaverDef to_proto()
500
515
/// <summary>
/// Calls <c>save</c> for <paramref name="file_prefix"/> and returns
/// <c>array_ops.identity(file_prefix)</c>, created inside a "cpu:0" device scope
/// and a control-dependency scope on the save operation.
/// </summary>
/// <param name="file_prefix">Tensor passed to <c>save</c> and echoed back through identity.</param>
/// <returns>The tensor produced by <c>array_ops.identity(file_prefix)</c>.</returns>
private Tensor _traced_save(Tensor file_prefix)
{
    var save_op = save(file_prefix);

    // Both scopes are entered through tf_with so they are active only while the
    // identity op is being created, instead of relying on a bare tf.device(...) call.
    return tf_with(ops.device("cpu:0"), device_scope =>
    {
        var dependencies = new object[] { save_op };
        // NOTE(review): presumably the control dependency makes the returned tensor
        // depend on the save op having run - confirm against ops.control_dependencies.
        return tf_with(ops.control_dependencies(dependencies), dependency_scope =>
        {
            return array_ops.identity(file_prefix);
        });
    });
}
509
526
510
527
/// <summary>
/// Calls <c>restore</c> for <paramref name="file_prefix"/> and returns
/// <c>array_ops.identity(file_prefix)</c>, created inside a "cpu:0" device scope
/// and a control-dependency scope on the values returned by the restore call.
/// </summary>
/// <param name="file_prefix">Tensor passed to <c>restore</c> and echoed back through identity.</param>
/// <returns>The tensor produced by <c>array_ops.identity(file_prefix)</c>.</returns>
private Tensor _traced_restore(Tensor file_prefix)
{
    var restore_op = restore(file_prefix);

    // Both scopes are entered through tf_with so they are active only while the
    // identity op is being created, instead of relying on a bare tf.device(...) call.
    return tf_with(ops.device("cpu:0"), device_scope =>
    {
        var dependencies = restore_op.Values.ToArray();
        // NOTE(review): presumably the control dependency makes the returned tensor
        // depend on the restore ops having run - confirm against ops.control_dependencies.
        return tf_with(ops.control_dependencies(dependencies), dependency_scope =>
        {
            return array_ops.identity(file_prefix);
        });
    });
}
519
538
520
539
public static MultiDeviceSaver from_saveables ( IEnumerable < MySaveableObject > saveables , IDictionary < string , IDictionary < string , Trackable > > ? registered_savers = null , bool call_with_mapped_captures = false )
0 commit comments