
Commit 090dc1e

Merge pull request #1190 from dogvane/master
Fix the bug where GPU memory is exhausted when training on the GPU in Keras mode.
2 parents 43c3705 + baf620a commit 090dc1e

File tree

20 files changed: +983 −20 lines


data/img001.bmp

174 KB (binary file not shown)

src/TensorFlowNET.Core/APIs/tf.image.cs

+7

@@ -339,6 +339,13 @@ public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype
             => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
                 name: name, expand_animations: expand_animations);
 
+        public Tensor encode_png(Tensor contents, string name = null)
+            => image_ops_impl.encode_png(contents, name: name);
+
+        public Tensor encode_jpeg(Tensor contents, string name = null)
+            => image_ops_impl.encode_jpeg(contents, name: name);
+
         /// <summary>
         /// Convenience function to check if the 'contents' encodes a JPEG image.
         /// </summary>

src/TensorFlowNET.Core/APIs/tf.io.cs

+7

@@ -16,6 +16,7 @@ limitations under the License.
 
 using System.Collections.Generic;
 using Tensorflow.IO;
+using Tensorflow.Operations;
 
 namespace Tensorflow
 {
@@ -46,6 +47,12 @@ public Operation save_v2(Tensor prefix, string[] tensor_names,
         public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
             string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
             => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
+
+        public Operation write_file(string filename, Tensor conentes, string name = null)
+            => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name);
+
+        public Operation write_file(Tensor filename, Tensor conentes, string name = null)
+            => gen_ops.write_file(filename, conentes, name);
     }
 
     public GFile gfile = new GFile();
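
Together with the new encode_png/encode_jpeg wrappers in tf.image, this makes it possible to serialize an image tensor and write it straight to disk. A minimal sketch, assuming `image` is an existing uint8 tensor of shape (H, W, C); the output path is illustrative:

    // Encode a uint8 HWC image tensor to PNG bytes and write it to disk.
    var png_bytes = tf.image.encode_png(image);    // scalar string tensor holding the PNG payload
    tf.io.write_file("output.png", png_bytes);     // the string overload converts the path to a tensor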

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

+5

@@ -80,6 +80,11 @@ BackwardFunction GetGradientFunction(string op_name,
             Tensor[] op_outputs)
             => (out_grads, unneeded_gradients) =>
             {
+                if(!ops.gradientFunctions.ContainsKey(op_name))
+                {
+                    throw new Exception($"gradientFunctions not find op_name: {op_name}");
+                }
+
                 if (ops.gradientFunctions[op_name] == null)
                     return new Tensor[op_inputs.Length];
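
This check surfaces a descriptive error instead of an unhandled KeyNotFoundException when backpropagation reaches an op with no registered gradient. A sketch of what a caller would now observe; the unregistered op below is hypothetical:

    using var tape = tf.GradientTape();
    var y = some_op_without_gradient(x);   // hypothetical op with no entry in ops.gradientFunctions
    // tape.gradient(y, x) now throws: "gradientFunctions not find op_name: <OpName>"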

src/TensorFlowNET.Core/Gradients/nn_grad.cs

+31

@@ -229,6 +229,37 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
             };
         }
 
+        /// <summary>
+        /// Gradient function for DepthwiseConv2dNative.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="grads"></param>
+        /// <returns></returns>
+        [RegisterGradient("DepthwiseConv2dNative")]
+        public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
+        {
+            var dilations = op.get_attr_list<int>("dilations");
+            var strides = op.get_attr_list<int>("strides");
+            var padding = op.get_attr<string>("padding");
+            var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
+            var data_format = op.get_attr<string>("data_format");
+            var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
+
+            return new Tensor[]
+            {
+                gen_nn_ops.depthwise_conv2d_native_backprop_input(
+                    shape[0], op.inputs[1], grads[0],
+                    strides, padding, explicit_paddings,
+                    dilations: dilations,
+                    data_format: data_format),
+                gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
+                    strides, padding,
+                    dilations: dilations,
+                    explicit_paddings: explicit_paddings,
+                    data_format: data_format)
+            };
+        }
+
         [RegisterGradient("FusedBatchNorm")]
         public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
             => _BaseFusedBatchNormGrad(op, 0, grads);
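
With the DepthwiseConv2dNative gradient registered, eager backpropagation through a depthwise convolution no longer fails with a missing-gradient error. A rough sketch, assuming the Keras DepthwiseConv2D layer added in this PR lowers to this op and that the overloads used below behave as declared elsewhere in the change; shapes and values are illustrative:

    var layer = tf.keras.layers.DepthwiseConv2D(kernel_size: new Shape(3, 3), padding: "same");
    var x = tf.random.normal(new Shape(1, 8, 8, 3));   // NHWC input
    using var tape = tf.GradientTape();
    var y = layer.Apply(x, training: true);
    // input and filter gradients now flow through the new backprop functions
    var grads = tape.gradient(tf.reduce_sum(y), layer.TrainableVariables);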

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

+23

@@ -24,6 +24,7 @@ ICallback fit(NDArray x, NDArray y,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
             ValidationDataPack validation_data = null,
+            int validation_step = 10,
             bool shuffle = true,
             Dictionary<int, float> class_weight = null,
             NDArray sample_weight = null,
@@ -47,6 +48,20 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
             int workers = 1,
             bool use_multiprocessing = false);
 
+        public ICallback fit(IDatasetV2 dataset,
+            int batch_size = -1,
+            int epochs = 1,
+            int verbose = 1,
+            List<ICallback> callbacks = null,
+            IDatasetV2 validation_data = null,
+            int validation_step = 10, // interval between validation runs
+            bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            int initial_epoch = 0,
+            int max_queue_size = 10,
+            int workers = 1,
+            bool use_multiprocessing = false);
+
         void save(string filepath,
             bool overwrite = true,
             bool include_optimizer = true,
@@ -85,6 +100,14 @@ Tensors predict(Tensors x,
             int workers = 1,
             bool use_multiprocessing = false);
 
+        public Tensors predict(IDatasetV2 dataset,
+            int batch_size = -1,
+            int verbose = 0,
+            int steps = -1,
+            int max_queue_size = 10,
+            int workers = 1,
+            bool use_multiprocessing = false);
+
         void summary(int line_length = -1, float[] positions = null);
 
         IKerasConfig get_config();
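
A sketch of how the new IDatasetV2 overloads could be called once a model is compiled; the dataset variables and hyperparameters below are illustrative:

    // `train_ds`, `val_ds` and `test_ds` are assumed to be prepared, already-batched IDatasetV2 pipelines.
    model.fit(train_ds,
        epochs: 10,
        validation_data: val_ds,
        validation_step: 10);   // validate at the configured interval
    Tensors predictions = model.predict(test_ds, verbose: 1);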

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+19

@@ -55,6 +55,12 @@ public ILayer Conv1D(int filters,
             string kernel_initializer = "glorot_uniform",
             string bias_initializer = "zeros");
 
+        public ILayer Conv2D(int filters,
+            Shape kernel_size = null,
+            Shape strides = null,
+            string padding = "valid"
+            );
+
         public ILayer Conv2D(int filters,
             Shape kernel_size = null,
             Shape strides = null,
@@ -95,6 +101,19 @@ public ILayer Conv2D(int filters,
             bool use_bias = true,
             string kernel_initializer = "glorot_uniform",
             string bias_initializer = "zeros");
+        public ILayer DepthwiseConv2D(Shape kernel_size = null,
+            Shape strides = null,
+            string padding = "valid",
+            string data_format = null,
+            Shape dilation_rate = null,
+            int groups = 1,
+            int depth_multiplier = 1,
+            string activation = null,
+            bool use_bias = false,
+            string kernel_initializer = "glorot_uniform",
+            string bias_initializer = "zeros",
+            string depthwise_initializer = "glorot_uniform"
+            );
 
         public ILayer Dense(int units);
         public ILayer Dense(int units,
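
A minimal model sketch that exercises the new simplified Conv2D overload and the DepthwiseConv2D layer; the layer sizes are illustrative and the Sequential/Flatten/Dense calls are assumed to behave as in the existing Keras API:

    var layers = tf.keras.layers;
    var model = tf.keras.Sequential(new List<ILayer>
    {
        layers.Conv2D(32, kernel_size: new Shape(3, 3), strides: new Shape(1, 1), padding: "same"),
        layers.DepthwiseConv2D(kernel_size: new Shape(3, 3), padding: "same", depth_multiplier: 1),
        layers.Flatten(),
        layers.Dense(10)
    });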

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

+32 −11

@@ -102,7 +102,10 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
             {
                 throw new ValueError("\'image\' must be fully defined.");
             }
-            var dims = image_shape["-3:"];
+            var dims = new Shape(new[] {
+                image_shape.dims[image_shape.dims.Length - 3],
+                image_shape.dims[image_shape.dims.Length - 2],
+                image_shape.dims[image_shape.dims.Length - 1]});
             foreach (var dim in dims.dims)
             {
                 if (dim == 0)
@@ -112,16 +115,18 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
             }
 
             var image_shape_last_three_elements = new Shape(new[] {
-                image_shape.dims[image_shape.dims.Length - 1],
+                image_shape.dims[image_shape.dims.Length - 3],
                 image_shape.dims[image_shape.dims.Length - 2],
-                image_shape.dims[image_shape.dims.Length - 3]});
+                image_shape.dims[image_shape.dims.Length - 1]});
             if (!image_shape_last_three_elements.IsFullyDefined)
             {
                 Tensor image_shape_ = array_ops.shape(image);
-                var image_shape_return = tf.constant(new[] {
-                    image_shape_.dims[image_shape.dims.Length - 1],
-                    image_shape_.dims[image_shape.dims.Length - 2],
-                    image_shape_.dims[image_shape.dims.Length - 3]});
+                var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });
+
+                //var image_shape_return = tf.constant(new[] {
+                //    image_shape_.dims[image_shape_.dims.Length - 3],
+                //    image_shape_.dims[image_shape_.dims.Length - 2],
+                //    image_shape_.dims[image_shape_.dims.Length - 1]});
 
                 return new Operation[] {
                     check_ops.assert_positive(
@@ -209,10 +214,10 @@ internal static Tensor _random_flip(Tensor image, int flip_index, int seed, stri
         }
 
         public static Tensor flip_left_right(Tensor image)
-            => _flip(image, 0, "flip_left_right");
+            => _flip(image, 1, "flip_left_right");
 
         public static Tensor flip_up_down(Tensor image)
-            => _flip(image, 1, "flip_up_down");
+            => _flip(image, 0, "flip_up_down");
 
         internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
         {
@@ -223,11 +228,11 @@ internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
             Shape shape = image.shape;
             if (shape.ndim == 3 || shape.ndim == Unknown)
             {
-                return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
+                return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
             }
             else if (shape.ndim == 4)
             {
-                return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
+                return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
             }
             else
             {
@@ -2047,6 +2052,22 @@ internal static (Tensor, Tensor) non_max_suppression_padded_v1(Tensor boxes, Ten
             });
         }
 
+        public static Tensor encode_jpeg(Tensor contents, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
+            {
+                return gen_ops.encode_jpeg(contents, name: name);
+            });
+        }
+
+        public static Tensor encode_png(Tensor contents, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "encode_png"), scope =>
+            {
+                return gen_ops.encode_png(contents, name: name);
+            });
+        }
+
         public static Tensor is_jpeg(Tensor contents, string name = null)
         {
             return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
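
The flip fixes align the reversed axis with the image layout: for a rank-3 HWC image, flip_left_right now reverses the width axis (1) and flip_up_down the height axis (0); for a rank-4 NHWC batch the reversed axes become 2 and 1. A small behavioural sketch with an illustrative 2×2 single-channel image:

    var img = tf.reshape(tf.range(4), new Shape(2, 2, 1));   // rows: [0, 1] and [2, 3]
    var lr = tf.image.flip_left_right(img);                   // width reversed  -> rows [1, 0] and [3, 2]
    var ud = tf.image.flip_up_down(img);                      // height reversed -> rows [2, 3] and [0, 1]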

src/TensorFlowNET.Core/Tensors/tensor_util.cs

+4 −1

@@ -249,6 +249,9 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
                 case sbyte val:
                     tensor_proto.IntVal.AddRange(new[] { (int)val });
                     break;
+                case byte val:
+                    tensor_proto.IntVal.AddRange(new[] { (int)val });
+                    break;
                 case int val:
                     tensor_proto.IntVal.AddRange(new[] { val });
                     break;
@@ -262,7 +265,7 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
                     tensor_proto.DoubleVal.AddRange(new[] { val });
                     break;
                 default:
-                    throw new Exception("make_tensor_proto Not Implemented");
+                    throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}");
             }
         }
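
With the byte case handled, scalar byte values can be packed into a TensorProto instead of falling into the default branch, e.g. when a uint8 scalar is turned into a constant on the graph path. A tiny sketch, assuming tf.constant routes scalar values through make_tensor_proto:

    byte pixel = 255;
    var c = tf.constant(pixel);   // previously threw "make_tensor_proto Not Implemented"; the new message also names the offending type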

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+3

@@ -132,6 +132,7 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
                     var end_step = step + data_handler.StepIncrement;
                     if (!is_val)
                         callbacks.on_test_batch_end(end_step, logs);
+                    GC.Collect();
                 }
             }
             callbacks.on_test_end(logs);
@@ -167,7 +168,9 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
         Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
         {
             (x,y) = data_handler.DataAdapter.Expand1d(x, y);
+
             var y_pred = Apply(x, training: false);
+
             var loss = compiled_loss.Call(y, y_pred);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+6 −6

@@ -41,6 +41,7 @@ public ICallback fit(NDArray x, NDArray y,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
             ValidationDataPack validation_data = null,
+            int validation_step = 10,
             bool shuffle = true,
             Dictionary<int, float> class_weight = null,
             NDArray sample_weight = null,
@@ -147,7 +148,7 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             }
         }
 
-        public History fit(IDatasetV2 dataset,
+        public ICallback fit(IDatasetV2 dataset,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
@@ -156,7 +157,6 @@ public History fit(IDatasetV2 dataset,
             int validation_step = 10,
             bool shuffle = true,
             Dictionary<int, float> class_weight = null,
-            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -170,7 +170,7 @@ public History fit(IDatasetV2 dataset,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
-                SampleWeight = sample_weight,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -218,6 +218,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
                     var end_step = step + data_handler.StepIncrement;
                     End_step = end_step;
                     callbacks.on_train_batch_end(end_step, logs);
+                    GC.Collect();
                 }
 
                 if (validation_data != null)
@@ -233,11 +234,10 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
                     callbacks.on_train_batch_end(End_step, logs);
                 }
 
+                GC.Collect();
 
                 callbacks.on_epoch_end(epoch, logs);
 
-                GC.Collect();
-                GC.WaitForPendingFinalizers();
                 if (stop_training)
                 {
                     break;
@@ -282,6 +282,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
                     var end_step = step + data_handler.StepIncrement;
                     End_step = end_step;
                     callbacks.on_train_batch_end(end_step, logs);
+                    GC.Collect();
                 }
 
                 if (validation_data != null)
@@ -301,7 +302,6 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
                 callbacks.on_epoch_end(epoch, logs);
 
                 GC.Collect();
-                GC.WaitForPendingFinalizers();
                 if (stop_training)
                 {
                     break;

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

+1 −1

@@ -102,9 +102,9 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
                     for (int i = 0; i < batch_outputs.Length; i++)
                         batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
                 }
-
                 var end_step = step + data_handler.StepIncrement;
                 callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
+                GC.Collect();
             }
         }