Skip to content

Commit a075bba

Browse files
authored
Merge pull request #1023 from AsakusaRinne/fix_save_model
Fix the error when saving model with GPU.
2 parents 1aa4de6 + 0360fbb commit a075bba

15 files changed

+568
-92
lines changed

src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
5353
var g = to_graph.as_default();
5454
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
5555
object_map, call_with_mapped_captures, saveables_cache);
56-
tf.device("/cpu:0");
57-
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
56+
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
57+
{
58+
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception.
59+
return constant_op.constant(graph_proto.ToByteArray());
60+
});
5861
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
5962
g.Exit();
6063
return (named_saveable_objects, registered_savers);
@@ -65,8 +68,10 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
6568
{
6669
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
6770
object_map, call_with_mapped_captures, saveables_cache);
68-
tf.device("/cpu:0");
69-
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
71+
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
72+
{
73+
return constant_op.constant(graph_proto.ToString());
74+
});
7075
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
7176
return (named_saveable_objects, registered_savers);
7277
}

src/TensorFlowNET.Core/Checkpoint/checkpoint.cs

+13-7
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ public TrackableSaver(ObjectGraphView graph_view)
5858

5959
if(object_graph_tensor is null)
6060
{
61-
tf.device("/cpu:0");
62-
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
61+
tf_with(ops.device("/cpu:0"), _ =>
62+
{
63+
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
64+
});
6365
}
6466
else
6567
{
@@ -230,22 +232,26 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
230232
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);
231233

232234
Dictionary<Tensor, string> file_prefix_feed_dict;
233-
Tensor file_prefix_tensor;
235+
Tensor file_prefix_tensor = null;
234236
if (graph_building)
235237
{
236238
if(_file_prefix_placeholder is null)
237239
{
238-
tf.device("/cpu:0");
239-
_file_prefix_placeholder = constant_op.constant("model");
240+
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
241+
{
242+
return constant_op.constant("model");
243+
});
240244
}
241245
file_prefix_tensor = _file_prefix_placeholder;
242246
file_prefix_feed_dict = new();
243247
file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
244248
}
245249
else
246250
{
247-
tf.device("/cpu:0");
248-
file_prefix_tensor = constant_op.constant(save_path);
251+
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
252+
{
253+
return constant_op.constant(save_path);
254+
});
249255
file_prefix_feed_dict = null;
250256
}
251257
TrackableObjectGraph object_graph_proto = new();

src/TensorFlowNET.Core/Checkpoint/functional_saver.cs

+74-55
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,11 @@ public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_pref
211211

212212
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;
213213

214-
// tf python has code `with ops.device(restore_device):` here.
215-
tf.device(restore_device); // may be risky.
216-
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
214+
Tensor[] restored_tensors = null;
215+
tf_with(ops.device(restore_device), _ =>
216+
{
217+
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
218+
});
217219

218220
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
219221
int idx = 0;
@@ -338,11 +340,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
338340
options = new CheckpointOptions();
339341
}
340342

341-
tf.device("CPU"); // may be risky.
342-
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
343+
Tensor tmp_checkpoint_prefix = null;
344+
tf_with(ops.device("CPU"), _ =>
345+
{
346+
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
343347
constant_op.constant(".part"), constant_op.constant("_temp/part"));
344-
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
345-
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
348+
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
349+
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
350+
});
346351

347352
Operation save_fn()
348353
{
@@ -364,16 +369,24 @@ Operation save_fn()
364369
var saver = pair.Value;
365370
last_device = device;
366371
// skip the extra process of device name because of lack of API.
367-
tf.device(device);
368-
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
372+
Tensor shard_prefix = null;
373+
tf_with(ops.device(device), _ =>
374+
{
375+
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
376+
});
369377
saved_prefixes.Add(shard_prefix);
370-
sharded_saves.Add(saver.save(shard_prefix, options));
378+
tf_with(ops.device(device), _ =>
379+
{
380+
sharded_saves.Add(saver.save(shard_prefix, options));
381+
});
371382
}
372383
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
373384
{
374385
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
375-
tf.device(merge_device);
376-
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
386+
return tf_with(ops.device(merge_device), _ =>
387+
{
388+
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
389+
});
377390
}
378391
}
379392

@@ -407,54 +420,56 @@ IDictionary<string, Operation> restore_func()
407420
{
408421
var device = single_saver.Key;
409422
var saver = single_saver.Value;
410-
tf.device(device);
411-
var restored_tensor_dict = saver.restore(file_prefix, options);
412-
413-
foreach(var pair in restored_tensor_dict)
423+
tf_with(ops.device(device), _ =>
414424
{
415-
var checkpoint_key = pair.Key;
416-
var slice_and_tensor = pair.Value;
417-
foreach(var item in slice_and_tensor)
425+
var restored_tensor_dict = saver.restore(file_prefix, options);
426+
427+
foreach (var pair in restored_tensor_dict)
418428
{
419-
var slice_spec = item.Key;
420-
var tensor = item.Value;
421-
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
422-
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
423-
if (!string.IsNullOrEmpty(slice_spec))
429+
var checkpoint_key = pair.Key;
430+
var slice_and_tensor = pair.Value;
431+
foreach (var item in slice_and_tensor)
424432
{
425-
if (!internal_dict.ContainsKey(checkpoint_key))
433+
var slice_spec = item.Key;
434+
var tensor = item.Value;
435+
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
436+
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
437+
if (!string.IsNullOrEmpty(slice_spec))
426438
{
427-
Dictionary<string, Tensor> dict = new();
428-
dict[slice_spec] = tensor;
429-
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
439+
if (!internal_dict.ContainsKey(checkpoint_key))
440+
{
441+
Dictionary<string, Tensor> dict = new();
442+
dict[slice_spec] = tensor;
443+
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
444+
}
445+
else
446+
{
447+
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
448+
}
430449
}
431450
else
432451
{
433-
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
452+
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
434453
}
435-
}
436-
else
437-
{
438-
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
439-
}
440-
restore_fn_input_count[restore_fn]--;
454+
restore_fn_input_count[restore_fn]--;
441455

442-
if (restore_fn_input_count[restore_fn] == 0)
443-
{
444-
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
445-
foreach(var input in restore_fn_inputs[restore_fn])
456+
if (restore_fn_input_count[restore_fn] == 0)
446457
{
447-
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
448-
}
449-
var ret = restore_fn.DynamicInvoke(restored_tensors);
450-
if(ret is IDictionary<string, Operation>)
451-
{
452-
var dict = (IDictionary<string, Operation>)ret;
453-
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
458+
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
459+
foreach (var input in restore_fn_inputs[restore_fn])
460+
{
461+
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
462+
}
463+
var ret = restore_fn.DynamicInvoke(restored_tensors);
464+
if (ret is IDictionary<string, Operation>)
465+
{
466+
var dict = (IDictionary<string, Operation>)ret;
467+
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
468+
}
454469
}
455470
}
456471
}
457-
}
472+
});
458473
}
459474

460475
foreach(var item in _registered_savers)
@@ -500,21 +515,25 @@ public SaverDef to_proto()
500515
private Tensor _traced_save(Tensor file_prefix)
501516
{
502517
var save_op = save(file_prefix);
503-
tf.device("cpu:0");
504-
using (ops.control_dependencies(new object[]{ save_op }))
518+
return tf_with(ops.device("cpu:0"), _ =>
505519
{
506-
return array_ops.identity(file_prefix);
507-
}
520+
return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
521+
{
522+
return array_ops.identity(file_prefix);
523+
});
524+
});
508525
}
509526

510527
private Tensor _traced_restore(Tensor file_prefix)
511528
{
512529
var restore_op = restore(file_prefix);
513-
tf.device("cpu:0");
514-
using (ops.control_dependencies(restore_op.Values.ToArray()))
530+
return tf_with(ops.device("cpu:0"), _ =>
515531
{
516-
return array_ops.identity(file_prefix);
517-
}
532+
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
533+
{
534+
return array_ops.identity(file_prefix);
535+
});
536+
});
518537
}
519538

520539
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)

src/TensorFlowNET.Core/Contexts/Context.Device.cs

+73
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using static Tensorflow.Binding;
2222
using Google.Protobuf;
2323
using Tensorflow.Device;
24+
using Tensorflow.Exceptions;
2425
using System.Collections.Generic;
2526

2627
namespace Tensorflow.Contexts
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
3031
/// </summary>
3132
public sealed partial class Context
3233
{
34+
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
35+
internal List<LogicalDevice> _logical_devices = null;
36+
internal List<string> _context_devices = null;
37+
3338
ContextDevicePlacementPolicy _device_policy;
3439
bool _log_device_placement;
40+
int _num_gpus;
3541
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();
3642

43+
public string DeviceName { get; set; } = "";
44+
public DeviceSpec DeviceSpec { get; set; } = null;
45+
46+
internal List<string> Devices
47+
{
48+
get
49+
{
50+
if(_context_devices is null)
51+
{
52+
throw new AssertionError("Context must be initialized first.");
53+
}
54+
return _context_devices;
55+
}
56+
}
57+
3758
public void log_device_placement(bool enable)
3859
{
3960
if (_handle != null)
@@ -89,5 +110,57 @@ public PhysicalDevice[] list_physical_devices(string device_type = null)
89110

90111
return results.ToArray();
91112
}
113+
114+
public EagerDeviceContext device(string name)
115+
{
116+
return new EagerDeviceContext(this, name);
117+
}
118+
119+
internal void _set_device(string device_name, DeviceSpec device_spec)
120+
{
121+
DeviceSpec = device_spec;
122+
DeviceName = device_name;
123+
}
124+
125+
internal void _initialize_logical_devices()
126+
{
127+
List<LogicalDevice> logical_devices = new();
128+
List<string> context_devices = new();
129+
Status status = new();
130+
var device_list = c_api.TFE_ContextListDevices(_handle, status);
131+
status.Check(true);
132+
try
133+
{
134+
this._num_gpus = 0;
135+
string current_job = null;
136+
int current_task = -1;
137+
for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++)
138+
{
139+
var dev_name = c_api.TF_DeviceListName(device_list, i, status);
140+
status.Check(true);
141+
context_devices.Add(DeviceUtils.canonical_name(dev_name));
142+
var spec = DeviceSpec.from_string(dev_name);
143+
if(spec.Job == "localhost")
144+
{
145+
spec = spec.replace(job: null, replica: -1, task: -1);
146+
}
147+
logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));
148+
var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
149+
var dev_type = c_api.StringPiece(dev_type_memory);
150+
status.Check(true);
151+
if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
152+
{
153+
_num_gpus++;
154+
}
155+
}
156+
}
157+
finally
158+
{
159+
_logical_devices = logical_devices;
160+
_context_devices = context_devices;
161+
}
162+
}
92163
}
164+
165+
public record class LogicalDevice(string name, string device_type);
93166
}

src/TensorFlowNET.Core/Contexts/Context.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ public sealed partial class Context
3434
public const int EAGER_MODE = 1;
3535

3636
int defaultExecutionMode = EAGER_MODE;
37-
public string DeviceName { get; set; } = "";
3837
public string ScopeName { get; set; } = "";
3938
bool initialized = false;
4039
ContextSwitchStack context_switches;
@@ -62,6 +61,8 @@ public void ensure_initialized()
6261
if (initialized)
6362
return;
6463

64+
Debug.Assert(_context_devices is null);
65+
6566
Config = MergeConfig();
6667
FunctionCallOptions.Config = Config;
6768
var config_str = Config.ToByteArray();
@@ -72,6 +73,7 @@ public void ensure_initialized()
7273
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
7374
_handle = c_api.TFE_NewContext(opts, status);
7475
status.Check(true);
76+
_initialize_logical_devices();
7577
initialized = true;
7678
}
7779

@@ -174,6 +176,7 @@ public void reset_context()
174176
{
175177
c_api.TFE_ContextClearCaches(_handle);
176178
}
179+
_device_parsing_cache.Clear();
177180
}
178181

179182
public static implicit operator SafeContextHandle(Context ctx)

0 commit comments

Comments (0)