diff --git a/data/img001.bmp b/data/img001.bmp
new file mode 100644
index 000000000..d149d76f1
Binary files /dev/null and b/data/img001.bmp differ
diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs
index ac9cbc60d..41ef52967 100644
--- a/src/TensorFlowNET.Core/APIs/tf.image.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.image.cs
@@ -339,6 +339,13 @@ public Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype
             => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
                 name: name, expand_animations: expand_animations);
 
+        public Tensor encode_png(Tensor contents, string name = null)
+            => image_ops_impl.encode_png(contents, name: name);
+
+        public Tensor encode_jpeg(Tensor contents, string name = null)
+            => image_ops_impl.encode_jpeg(contents, name: name);
+
+
         /// <summary>
         /// Convenience function to check if the 'contents' encodes a JPEG image.
         /// </summary>
diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs
index be1e86e6c..ea1e44b28 100644
--- a/src/TensorFlowNET.Core/APIs/tf.io.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.io.cs
@@ -16,6 +16,7 @@ limitations under the License.
 
 using System.Collections.Generic;
 using Tensorflow.IO;
+using Tensorflow.Operations;
 
 namespace Tensorflow
 {
@@ -46,6 +47,12 @@ public Operation save_v2(Tensor prefix, string[] tensor_names,
         public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
             string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
             => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
+
+        public Operation write_file(string filename, Tensor contents, string name = null)
+            => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), contents, name);
+
+        public Operation write_file(Tensor filename, Tensor contents, string name = null)
+            => gen_ops.write_file(filename, contents, name);
     }
 
     public GFile gfile = new GFile();
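
Together, tf.image.encode_png / encode_jpeg and tf.io.write_file give a complete decode-edit-save path. A minimal sketch of the intended usage, mirroring the ImageSaveTest added later in this patch (paths are illustrative; in graph mode the returned Operation still has to be run in a session):

    var contents = tf.io.read_file("data/img001.bmp");
    var image = tf.image.decode_image(contents);            // uint8 tensor, (H, W, C)
    var png = tf.image.encode_png(image);                   // scalar string tensor holding the PNG bytes
    var writeOp = tf.io.write_file("data/img001.png", png); // eager: written immediately; graph: session.run(writeOp)
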
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index 59d5fd030..2bdd65f5b 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -80,6 +80,11 @@ BackwardFunction GetGradientFunction(string op_name, Tensor[] op_outputs)
             => (out_grads, unneeded_gradients) =>
             {
+                if (!ops.gradientFunctions.ContainsKey(op_name))
+                {
+                    throw new Exception($"No gradient function registered for op: {op_name}");
+                }
+
                 if (ops.gradientFunctions[op_name] == null)
                     return new Tensor[op_inputs.Length];
 
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index a43a91b9a..87646a9ea 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -229,6 +229,37 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
             };
         }
 
+        /// <summary>
+        /// Gradient function for DepthwiseConv2dNative.
+        /// </summary>
+        /// <param name="op"></param>
+        /// <param name="grads"></param>
+        /// <returns></returns>
+        [RegisterGradient("DepthwiseConv2dNative")]
+        public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
+        {
+            var dilations = op.get_attr_list<int>("dilations");
+            var strides = op.get_attr_list<int>("strides");
+            var padding = op.get_attr<string>("padding");
+            var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
+            var data_format = op.get_attr<string>("data_format");
+            var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
+
+            return new Tensor[]
+            {
+                gen_nn_ops.depthwise_conv2d_native_backprop_input(
+                    shape[0], op.inputs[1], grads[0],
+                    strides, padding, explicit_paddings,
+                    dilations: dilations,
+                    data_format: data_format),
+                gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
+                    strides, padding,
+                    dilations: dilations,
+                    explicit_paddings: explicit_paddings,
+                    data_format: data_format)
+            };
+        }
+
         [RegisterGradient("FusedBatchNorm")]
         public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
             => _BaseFusedBatchNormGrad(op, 0, grads);
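
Registering this gradient is what makes the new DepthwiseConv2D layer trainable: the backward pass routes the incoming gradient to the input via depthwise_conv2d_native_backprop_input and to the filter via depthwise_conv2d_native_backprop_filter. A rough sketch of what this enables, using the Keras surface added elsewhere in this patch (illustrative only, not part of the change):

    // Backprop through a depthwise convolution no longer throws
    // "No gradient function registered for op: DepthwiseConv2dNative".
    var layer = keras.layers.DepthwiseConv2D(kernel_size: 3, padding: "same");
    var x = tf.Variable(tf.ones((1, 9, 9, 3)));
    using var tape = tf.GradientTape();
    var loss = tf.reduce_mean(layer.Apply(x.AsTensor()));
    var grad = tape.gradient(loss, x);
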
= "glorot_uniform", + string bias_initializer = "zeros", + string depthwise_initializer = "glorot_uniform" + ); public ILayer Dense(int units); public ILayer Dense(int units, diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 318b8b142..f1aff28ee 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -102,7 +102,10 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat { throw new ValueError("\'image\' must be fully defined."); } - var dims = image_shape["-3:"]; + var dims = new Shape(new[] { + image_shape.dims[image_shape.dims.Length - 3], + image_shape.dims[image_shape.dims.Length - 2], + image_shape.dims[image_shape.dims.Length - 1]}); foreach (var dim in dims.dims) { if (dim == 0) @@ -112,16 +115,18 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat } var image_shape_last_three_elements = new Shape(new[] { - image_shape.dims[image_shape.dims.Length - 1], + image_shape.dims[image_shape.dims.Length - 3], image_shape.dims[image_shape.dims.Length - 2], - image_shape.dims[image_shape.dims.Length - 3]}); + image_shape.dims[image_shape.dims.Length - 1]}); if (!image_shape_last_three_elements.IsFullyDefined) { Tensor image_shape_ = array_ops.shape(image); - var image_shape_return = tf.constant(new[] { - image_shape_.dims[image_shape.dims.Length - 1], - image_shape_.dims[image_shape.dims.Length - 2], - image_shape_.dims[image_shape.dims.Length - 3]}); + var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 }); + + //var image_shape_return = tf.constant(new[] { + // image_shape_.dims[image_shape_.dims.Length - 3], + // image_shape_.dims[image_shape_.dims.Length - 2], + // image_shape_.dims[image_shape_.dims.Length - 1]}); return new Operation[] { check_ops.assert_positive( @@ -209,10 +214,10 @@ internal static Tensor _random_flip(Tensor image, int flip_index, int seed, stri } public static Tensor flip_left_right(Tensor image) - => _flip(image, 0, "flip_left_right"); + => _flip(image, 1, "flip_left_right"); public static Tensor flip_up_down(Tensor image) - => _flip(image, 1, "flip_up_down"); + => _flip(image, 0, "flip_up_down"); internal static Tensor _flip(Tensor image, int flip_index, string scope_name) { @@ -223,11 +228,11 @@ internal static Tensor _flip(Tensor image, int flip_index, string scope_name) Shape shape = image.shape; if (shape.ndim == 3 || shape.ndim == Unknown) { - return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index }))); + return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index }))); } else if (shape.ndim == 4) { - return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 })); + return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 })); } else { @@ -2047,6 +2052,22 @@ internal static (Tensor, Tensor) non_max_suppression_padded_v1(Tensor boxes, Ten }); } + public static Tensor encode_jpeg(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_jpeg"), scope => + { + return gen_ops.encode_jpeg(contents, name:name); + }); + } + + public static Tensor encode_png(Tensor contents, string name = null) + { + return tf_with(ops.name_scope(name, "encode_png"), scope => + { + return gen_ops.encode_png(contents, name: name); + }); 
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index 318b8b142..f1aff28ee 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -102,7 +102,10 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
             {
                 throw new ValueError("\'image\' must be fully defined.");
             }
-            var dims = image_shape["-3:"];
+            var dims = new Shape(new[] {
+                image_shape.dims[image_shape.dims.Length - 3],
+                image_shape.dims[image_shape.dims.Length - 2],
+                image_shape.dims[image_shape.dims.Length - 1]});
             foreach (var dim in dims.dims)
             {
                 if (dim == 0)
@@ -112,16 +115,18 @@ internal static Operation[] _CheckAtLeast3DImage(Tensor image, bool require_stat
             }
 
             var image_shape_last_three_elements = new Shape(new[] {
-                image_shape.dims[image_shape.dims.Length - 1],
+                image_shape.dims[image_shape.dims.Length - 3],
                 image_shape.dims[image_shape.dims.Length - 2],
-                image_shape.dims[image_shape.dims.Length - 3]});
+                image_shape.dims[image_shape.dims.Length - 1]});
             if (!image_shape_last_three_elements.IsFullyDefined)
             {
                 Tensor image_shape_ = array_ops.shape(image);
-                var image_shape_return = tf.constant(new[] {
-                    image_shape_.dims[image_shape.dims.Length - 1],
-                    image_shape_.dims[image_shape.dims.Length - 2],
-                    image_shape_.dims[image_shape.dims.Length - 3]});
+                var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });
+
+                //var image_shape_return = tf.constant(new[] {
+                //    image_shape_.dims[image_shape_.dims.Length - 3],
+                //    image_shape_.dims[image_shape_.dims.Length - 2],
+                //    image_shape_.dims[image_shape_.dims.Length - 1]});
 
                 return new Operation[] {
                     check_ops.assert_positive(
@@ -209,10 +214,10 @@ internal static Tensor _random_flip(Tensor image, int flip_index, int seed, stri
         }
 
         public static Tensor flip_left_right(Tensor image)
-            => _flip(image, 0, "flip_left_right");
+            => _flip(image, 1, "flip_left_right");
 
         public static Tensor flip_up_down(Tensor image)
-            => _flip(image, 1, "flip_up_down");
+            => _flip(image, 0, "flip_up_down");
 
         internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
         {
@@ -223,11 +228,11 @@ internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
             Shape shape = image.shape;
             if (shape.ndim == 3 || shape.ndim == Unknown)
             {
-                return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
+                return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
             }
             else if (shape.ndim == 4)
             {
-                return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
+                return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
             }
             else
             {
@@ -2047,6 +2052,22 @@ internal static (Tensor, Tensor) non_max_suppression_padded_v1(Tensor boxes, Ten
             });
         }
 
+        public static Tensor encode_jpeg(Tensor contents, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
+            {
+                return gen_ops.encode_jpeg(contents, name: name);
+            });
+        }
+
+        public static Tensor encode_png(Tensor contents, string name = null)
+        {
+            return tf_with(ops.name_scope(name, "encode_png"), scope =>
+            {
+                return gen_ops.encode_png(contents, name: name);
+            });
+        }
+
         public static Tensor is_jpeg(Tensor contents, string name = null)
         {
             return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
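
The flip fixes pin down the axis conventions: for a 3-D (H, W, C) image, height is axis 0 and width is axis 1, and the 4-D (N, H, W, C) branch shifts both by one (the old `(flip_index + 1) % 2` wrapped back onto the batch axis). As a quick reference, with `img` an assumed 3-D image tensor (a sketch, not part of the patch):

    var lr = tf.image.flip_left_right(img);   // same result as tf.reverse(img, 1)  (width)
    var ud = tf.image.flip_up_down(img);      // same result as tf.reverse(img, 0)  (height)
    // For a 4-D batch the same calls reverse axis 2 and axis 1, respectively.
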
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index e65c4850d..f688d4d5d 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -249,6 +249,9 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
                 case sbyte val:
                     tensor_proto.IntVal.AddRange(new[] { (int)val });
                     break;
+                case byte val:
+                    tensor_proto.IntVal.AddRange(new[] { (int)val });
+                    break;
                 case int val:
                     tensor_proto.IntVal.AddRange(new[] { val });
                     break;
@@ -262,7 +265,7 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
                     tensor_proto.DoubleVal.AddRange(new[] { val });
                     break;
                 default:
-                    throw new Exception("make_tensor_proto Not Implemented");
+                    throw new Exception($"make_tensor_proto not implemented for type {values.GetType().Name}");
             }
         }
 
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
index 94a2e6646..474d5e5a5 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -132,6 +132,7 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
                 var end_step = step + data_handler.StepIncrement;
                 if (!is_val)
                     callbacks.on_test_batch_end(end_step, logs);
+                GC.Collect();
             }
         }
         callbacks.on_test_end(logs);
@@ -167,7 +168,9 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
     Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
     {
         (x,y) = data_handler.DataAdapter.Expand1d(x, y);
+
         var y_pred = Apply(x, training: false);
+        var loss = compiled_loss.Call(y, y_pred);
         compiled_metrics.update_state(y, y_pred);
         return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 689fc9fb8..d61211c71 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -41,6 +41,7 @@ public ICallback fit(NDArray x, NDArray y,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
             ValidationDataPack validation_data = null,
+            int validation_step = 10,
             bool shuffle = true,
             Dictionary<int, float> class_weight = null,
             NDArray sample_weight = null,
@@ -147,7 +148,7 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             }
         }
 
-        public History fit(IDatasetV2 dataset,
+        public ICallback fit(IDatasetV2 dataset,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
@@ -156,7 +157,6 @@ public History fit(IDatasetV2 dataset,
             int validation_step = 10,
             bool shuffle = true,
             Dictionary<int, float> class_weight = null,
-            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -170,7 +170,7 @@ public History fit(IDatasetV2 dataset,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
-                SampleWeight = sample_weight,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -218,6 +218,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
                 var end_step = step + data_handler.StepIncrement;
                 End_step = end_step;
                 callbacks.on_train_batch_end(end_step, logs);
+                GC.Collect();
             }
 
             if (validation_data != null)
@@ -233,11 +234,10 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
                     callbacks.on_train_batch_end(End_step, logs);
                 }
 
+                GC.Collect();
                 callbacks.on_epoch_end(epoch, logs);
 
-                GC.Collect();
-                GC.WaitForPendingFinalizers();
                 if (stop_training)
                 {
                     break;
@@ -282,6 +282,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List
                 {
                     { "outputs", batch_outputs }
                 });
+                GC.Collect();
            }
        }
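
make_tensor_proto previously threw for byte inputs; with the new case they take the same IntVal path as sbyte. A sketch of what this unblocks (graph-mode constant creation is where make_tensor_proto is exercised; illustrative only):

    var graph = tf.Graph().as_default();
    byte pixel = 255;
    var c = tf.constant(pixel);   // previously: "make_tensor_proto Not Implemented"
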
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs
new file mode 100644
index 000000000..dae4a4036
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs
@@ -0,0 +1,167 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Utils;
+using Tensorflow.Operations;
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class DepthwiseConv2DArgs : Conv2DArgs
+    {
+        /// <summary>
+        /// depth_multiplier: The number of depthwise convolution output channels for
+        /// each input channel. The total number of depthwise convolution output
+        /// channels will be equal to `filters_in * depth_multiplier`.
+        /// </summary>
+        [JsonProperty("depth_multiplier")]
+        public int DepthMultiplier { get; set; } = 1;
+
+        [JsonProperty("depthwise_initializer")]
+        public IInitializer DepthwiseInitializer { get; set; }
+    }
+
+    public class DepthwiseConv2D : Conv2D
+    {
+        /// <summary>
+        /// The number of depthwise convolution output channels for each input
+        /// channel; the total number of output channels is `filters_in * depth_multiplier`.
+        /// </summary>
+        int DepthMultiplier = 1;
+
+        IInitializer DepthwiseInitializer;
+
+        int[] strides;
+
+        int[] dilation_rate;
+
+        string getDataFormat()
+        {
+            return data_format == "channels_first" ? "NCHW" : "NHWC";
+        }
+
+        static int _id = 1;
+
+        public DepthwiseConv2D(DepthwiseConv2DArgs args) : base(args)
+        {
+            args.Padding = args.Padding.ToUpper();
+
+            if (string.IsNullOrEmpty(args.Name))
+                name = "DepthwiseConv2D_" + _id;
+
+            this.DepthMultiplier = args.DepthMultiplier;
+            this.DepthwiseInitializer = args.DepthwiseInitializer;
+        }
+
+        public override void build(KerasShapesWrapper input_shape)
+        {
+            //base.build(input_shape);
+
+            var shape = input_shape.ToSingleShape();
+
+            int channel_axis = data_format == "channels_first" ? 1 : -1;
+            var input_channel = channel_axis < 0 ?
+                shape.dims[shape.ndim + channel_axis] :
+                shape.dims[channel_axis];
+
+            var arg = args as DepthwiseConv2DArgs;
+
+            if (arg.Strides.ndim != shape.ndim)
+            {
+                if (arg.Strides.ndim == 2)
+                {
+                    this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 };
+                }
+                else
+                {
+                    this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides");
+                }
+            }
+            else
+            {
+                this.strides = arg.Strides.dims.Select(o => (int)o).ToArray();
+            }
+
+            if (arg.DilationRate.ndim != shape.ndim)
+            {
+                this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate");
+            }
+
+            long channel_data = data_format == "channels_first" ? shape[0] : shape[shape.Length - 1];
+
+            var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] {
+                channel_data,
+                this.DepthMultiplier
+            });
+
+            this.kernel = this.add_weight(
+                shape: depthwise_kernel_shape,
+                initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer,
+                name: "depthwise_kernel",
+                trainable: true,
+                dtype: DType,
+                regularizer: this.kernel_regularizer
+            );
+
+            var axes = new Dictionary<int, int>();
+            axes.Add(-1, (int)input_channel);
+            inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);
+
+            if (use_bias)
+            {
+                bias = add_weight(name: "bias",
+                    shape: ((int)channel_data),
+                    initializer: bias_initializer,
+                    trainable: true,
+                    dtype: DType);
+            }
+
+            built = true;
+            _buildInputShape = input_shape;
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensors state = null,
+            bool? training = false, IOptionalArgs? optional_args = null)
+        {
+            Tensor outputs = null;
+
+            outputs = gen_nn_ops.depthwise_conv2d_native(
+                inputs,
+                filter: this.kernel.AsTensor(),
+                strides: this.strides,
+                padding: this.padding,
+                dilations: this.dilation_rate,
+                data_format: this.getDataFormat(),
+                name: name
+            );
+
+            if (use_bias)
+            {
+                if (data_format == "channels_first")
+                {
+                    throw new NotImplementedException("call channels_first");
+                }
+                else
+                {
+                    outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
+                        data_format: this.getDataFormat(), name: name);
+                }
+            }
+
+            if (activation != null)
+                outputs = activation.Apply(outputs);
+
+            return outputs;
+        }
+    }
+}
\ No newline at end of file
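
The layer follows Keras's DepthwiseConv2D: one kernel_size filter per input channel (times depth_multiplier), so the channel count is preserved by default. Minimal usage, in line with the unit tests added below:

    // All-ones depthwise weights make the output easy to predict.
    var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1,
        padding: "same", depthwise_initializer: "ones");

    var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
    var y = conv.Apply(ops.convert_to_tensor(x, TF_DataType.TF_FLOAT));
    // padding "same" + stride 1 + depth_multiplier 1  =>  y.shape == (2, 9, 9, 3)
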
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 928e7e337..bcc19dc22 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -112,7 +112,28 @@ public ILayer Conv1D(int filters,
                 KernelInitializer = GetInitializerByName(kernel_initializer),
                 BiasInitializer = GetInitializerByName(bias_initializer)
             });
-
+        public ILayer Conv2D(int filters,
+            Shape kernel_size = null,
+            Shape strides = null,
+            string padding = "valid")
+            => new Conv2D(new Conv2DArgs
+            {
+                Rank = 2,
+                Filters = filters,
+                KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
+                Strides = strides == null ? (1, 1) : strides,
+                Padding = padding,
+                DataFormat = null,
+                DilationRate = (1, 1),
+                Groups = 1,
+                UseBias = false,
+                KernelRegularizer = null,
+                KernelInitializer = tf.glorot_uniform_initializer,
+                BiasInitializer = tf.zeros_initializer,
+                BiasRegularizer = null,
+                ActivityRegularizer = null,
+                Activation = keras.activations.Linear,
+            });
         /// <summary>
         /// 2D convolution layer (e.g. spatial convolution over images).
         /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs.
@@ -210,6 +231,38 @@ public ILayer Conv2D(int filters,
                 Activation = keras.activations.GetActivationFromName(activation)
             });
 
+        public ILayer DepthwiseConv2D(Shape kernel_size = null,
+            Shape strides = null,
+            string padding = "valid",
+            string data_format = null,
+            Shape dilation_rate = null,
+            int groups = 1,
+            int depth_multiplier = 1,
+            string activation = null,
+            bool use_bias = false,
+            string kernel_initializer = "glorot_uniform",
+            string bias_initializer = "zeros",
+            string depthwise_initializer = "glorot_uniform"
+            )
+            => new DepthwiseConv2D(new DepthwiseConv2DArgs
+            {
+                Rank = 2,
+                Filters = 1,
+                KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
+                Strides = strides == null ? (1) : strides,
+                Padding = padding,
+                DepthMultiplier = depth_multiplier,
+                DataFormat = data_format,
+                DilationRate = dilation_rate == null ? (1) : dilation_rate,
+                Groups = groups,
+                UseBias = use_bias,
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                Activation = keras.activations.GetActivationFromName(activation),
+            });
+
+
         /// <summary>
         /// Transposed convolution layer (sometimes called Deconvolution).
         /// </summary>
diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
index d671b6096..127b65bf6 100644
--- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
@@ -4,6 +4,7 @@
 using Tensorflow;
 using static Tensorflow.Binding;
 using System;
+using System.IO;
 
 namespace TensorFlowNET.UnitTest
 {
@@ -164,5 +165,94 @@ public void TestCropAndResize()
             Assert.AreEqual(result.size, 16ul);
             Assert.AreEqual(result[0, 0, 0, 0], 12f);
         }
+
+        [TestMethod]
+        public void ImageSaveTest()
+        {
+            var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
+            var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg");
+            var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png");
+
+            File.Delete(jpegImgPath);
+            File.Delete(pngImgPath);
+
+            var contents = tf.io.read_file(imgPath);
+            var bmp = tf.image.decode_image(contents);
+            Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0");
+
+            var jpeg = tf.image.encode_jpeg(bmp);
+            var op1 = tf.io.write_file(jpegImgPath, jpeg);
+
+            var png = tf.image.encode_png(bmp);
+            var op2 = tf.io.write_file(pngImgPath, png);
+
+            this.session().run(op1);
+            this.session().run(op2);
+
+            Assert.IsTrue(File.Exists(jpegImgPath), "file not found: " + jpegImgPath);
+            Assert.IsTrue(File.Exists(pngImgPath), "file not found: " + pngImgPath);
+
+            // To inspect the generated images manually, comment out the two lines below.
+            File.Delete(jpegImgPath);
+            File.Delete(pngImgPath);
+        }
+
+        [TestMethod]
+        public void ImageFlipTest()
+        {
+            var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
+
+            var contents = tf.io.read_file(imgPath);
+            var bmp = tf.image.decode_image(contents);
+
+            // flip left-right
+            var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png");
+            File.Delete(lrImgPath);
+
+            var lr = tf.image.flip_left_right(bmp);
+            var png = tf.image.encode_png(lr);
+            var op = tf.io.write_file(lrImgPath, png);
+            this.session().run(op);
+
+            Assert.IsTrue(File.Exists(lrImgPath), "file not found: " + lrImgPath);
+
+            // flip up-down
+            var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png");
+            File.Delete(updownImgPath);
+
+            var updown = tf.image.flip_up_down(bmp);
+            var pngupdown = tf.image.encode_png(updown);
+            var op2 = tf.io.write_file(updownImgPath, pngupdown);
+            this.session().run(op2);
+            Assert.IsTrue(File.Exists(updownImgPath));
+
+            // For now the flips are verified by eye; delete the two lines below while inspecting the images.
+            File.Delete(lrImgPath);
+            File.Delete(updownImgPath);
+
+            // flip a batch of images
+            // The shape is currently taken straight from bmp; the batch is built with the default image size for now.
+            var mImg = tf.stack(new[] { bmp, lr }, axis: 0);
+            print(mImg.shape);
+
+            var up2 = tf.image.flip_up_down(mImg);
+
+            var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // flipped up-down only
+            File.Delete(updownImgPath_m1);
+
+            var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // flipped left-right, then up-down
+            File.Delete(img001_updown_m2);
+
+            var png2 = tf.image.encode_png(up2[0]);
+            tf.io.write_file(updownImgPath_m1, png2);
+
+            png2 = tf.image.encode_png(up2[1]);
+            tf.io.write_file(img001_updown_m2, png2);
+
+            // To inspect the generated images manually, comment out the two lines below.
+            File.Delete(updownImgPath_m1);
+            File.Delete(img001_updown_m2);
+        }
     }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
index c7eab364c..635f13a54 100644
--- a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
@@ -33,6 +33,40 @@ public bool Equal(float[] f1, float[] f2)
             return ret;
         }
 
+        public void AssertArray(int[] f1, int[] f2)
+        {
+            bool ret = false;
+            for (var i = 0; i < f1.Length; i++)
+            {
+                ret = f1[i] == f2[i];
+                if (!ret)
+                    break;
+            }
+
+            if (!ret)
+            {
+                Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
+            }
+        }
+
+        public void AssertArray(float[] f1, float[] f2)
+        {
+            bool ret = false;
+            var tolerance = .00001f;
+            for (var i = 0; i < f1.Length; i++)
+            {
+                ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
+                if (!ret)
+                    break;
+            }
+
+            if (!ret)
+            {
+                Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
+            }
+        }
+
         public bool Equal(double[] d1, double[] d2)
         {
             bool ret = false;
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
index 997dcb4f6..15c6e80fe 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
@@ -1,6 +1,8 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Linq;
 using Tensorflow.NumPy;
 using static Tensorflow.KerasApi;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Keras.UnitTest.Layers
 {
@@ -193,5 +195,128 @@ public void BasicConv2D_ksize_dilation_same()
             Assert.AreEqual(x.dims[2], y.shape[2]);
             Assert.AreEqual(filters, y.shape[3]);
         }
+
+        [TestMethod]
+        public void BasicDepthwiseConv2D()
+        {
+            var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null,
+                padding: "same", depthwise_initializer: "ones");
+
+            var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
+            var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+            var y = conv.Apply(x2);
+
+            print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+            Assert.AreEqual(4, y.shape.ndim);
+            var arr = y.numpy().reshape((2, 9, 9, 3));
+
+            AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
+            AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2457f, 2466f, 2475f });
+
+            var bn = keras.layers.BatchNormalization();
+            var y2 = bn.Apply(y);
+            arr = y2.numpy().ToArray<float>();
+
+            double delta = 0.0001; // tolerance
+
+            Assert.AreEqual(arr[0], 59.97002f, delta);
+            Assert.AreEqual(arr[1], 63.96802f, delta);
+        }
+
+        [TestMethod]
+        public void BasicDepthwiseConv2D_strides_2()
+        {
+            var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null,
+                padding: "same", depthwise_initializer: "ones");
+
+            var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
+            var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+            var y = conv.Apply(x2);
+
+            print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+            Assert.AreEqual(4, y.shape.ndim);
+            var arr = y.numpy().reshape((2, 5, 5, 3));
+
+            AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
+            AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2727f, 2736f, 2745f });
+
+            var bn = keras.layers.BatchNormalization();
+            var y2 = bn.Apply(y);
+            arr = y2.numpy().ToArray<float>();
+
+            double delta = 0.0001; // tolerance
+
+            Assert.AreEqual(arr[0], 59.97002f, delta);
+            Assert.AreEqual(arr[1], 63.96802f, delta);
+        }
+
+        [TestMethod]
+        public void BasicDepthwiseConv2D_strides_3()
+        {
+            var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null,
+                padding: "same", depthwise_initializer: "ones");
+
+            var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
+            var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+            var y = conv.Apply(x2);
+
+            print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+            Assert.AreEqual(4, y.shape.ndim);
+            var arr = y.numpy().reshape((2, 3, 3, 3));
+
+            AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
+            AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 3267f, 3276f, 3285f });
+
+            var bn = keras.layers.BatchNormalization();
+            var y2 = bn.Apply(y);
+            arr = y2.numpy().ToArray<float>();
+
+            double delta = 0.0001; // tolerance
+
+            Assert.AreEqual(arr[0], 269.86508f, delta);
+            Assert.AreEqual(arr[1], 278.8606f, delta);
+        }
+
+        [TestMethod]
+        public void BasicDepthwiseConv2D_UseBias()
+        {
+            var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null,
+                use_bias: true, padding: "same",
+                depthwise_initializer: "ones",
+                bias_initializer: "ones"
+                );
+
+            var weight = conv.get_weights();
+
+            var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3));
+            var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+            var y = conv.Apply(x2);
+
+            Assert.AreEqual(4, y.shape.ndim);
+            var arr = y.numpy().ToArray<float>();
+
+            Assert.AreEqual(arr[0], 61f);
+            Assert.AreEqual(arr[1], 65f);
+
+            var bn = keras.layers.BatchNormalization();
+            var y2 = bn.Apply(y);
+            arr = y2.numpy().ToArray<float>();
+
+            double delta = 0.0001; // tolerance
+
+            Assert.AreEqual(arr[0], 60.96952f, delta);
+            Assert.AreEqual(arr[1], 64.96752f, delta);
+        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
index d08f4e505..b7b9ae128 100644
--- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -20,6 +20,20 @@ public bool Equal(float f1, float f2)
             return Math.Abs(f1 - f2) <= tolerance;
         }
 
+        public bool Equal(long[] l1, long[] l2)
+        {
+            if (l1.Length != l2.Length)
+                return false;
+
+            for (var i = 0; i < l1.Length; i++)
+            {
+                if (l1[i] != l2[i])
+                    return false;
+            }
+
+            return true;
+        }
+
         public bool Equal(float[] f1, float[] f2)
         {
             bool ret = false;
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
index 675689bb1..e25c9779d 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
@@ -3,6 +3,7 @@
 using Tensorflow;
 using static Tensorflow.Binding;
 using System.Linq;
+using Tensorflow.Operations;
 
 namespace TensorFlowNET.UnitTest.ManagedAPI
 {
@@ -105,5 +106,321 @@ public void ReverseArray()
             Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>()));
             Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>()));
         }
+
+        [TestMethod]
+        public void ReverseImgArray3D()
+        {
+            // build the source image array
+            var sourceImgArray = new float[,,] {
+                {
+                    { 237, 28, 36 },
+                    { 255, 255, 255 },
+                    { 255, 255, 255 }
+                },
+                {
+                    { 255, 255, 255 },
+                    { 255, 255, 255 },
+                    { 255, 255, 255 }
+                }
+            };
+            var sourceImg = ops.convert_to_tensor(sourceImgArray);
+
+            // build the left-right flipped array
+            var lrImgArray = new float[,,] {
+                {
+                    { 255, 255, 255 },
+                    { 255, 255, 255 },
+                    { 237, 28, 36 }
+                },
+                {
+                    { 255, 255, 255 },
+                    { 255, 255, 255 },
+                    { 255, 255, 255 }
+                }
+            };
+            var lrImg = ops.convert_to_tensor(lrImgArray);
+
+            var lr = tf.image.flip_left_right(sourceImg);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");
+
+            var lr2 = tf.reverse(sourceImg, 1);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");
+
+            var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            // build the up-down flipped array
+            var udImgArray = new float[,,] {
+                {
+                    { 255, 255, 255 },
+                    { 255, 255, 255 },
+                    { 255, 255, 255 }
+                },
+                {
+                    { 237, 28, 36 },
+                    { 255, 255, 255 },
+                    { 255, 255, 255 }
+                }
+            };
+            var udImg = ops.convert_to_tensor(udImgArray);
+
+            var ud = tf.image.flip_up_down(sourceImg);
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");
+
+            var ud2 = tf.reverse(sourceImg, new Axis(0));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");
+
+            var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
+
+        [TestMethod]
+        public void ReverseImgArray4D()
+        {
+            // original image in the top-left, plus a left-right flipped copy
+            var m = new float[,,,] {
+                {
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                }
+            };
+            var sourceImg = ops.convert_to_tensor(m);
+
+            var lrArray = new float[,,,] {
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 },
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                }
+            };
+            var lrImg = ops.convert_to_tensor(lrArray);
+
+            // build the up-down flipped array
+            var udArray = new float[,,,] {
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 }
+                    }
+                }
+            };
+            var udImg = ops.convert_to_tensor(udArray);
+
+            var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            var ud2 = tf.reverse(sourceImg, new Axis(1));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");
+
+            var ud = tf.image.flip_up_down(sourceImg);
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");
+
+            // flip left-right
+            var lr = tf.image.flip_left_right(sourceImg);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");
+
+            var lr2 = tf.reverse(sourceImg, 0);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");
+
+            var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
+
+        [TestMethod]
+        public void ReverseImgArray4D_3x3()
+        {
+            // original image in the top-left, plus a left-right flipped copy
+            var m = new float[,,,] {
+                {
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                }
+            };
+            var sourceImg = ops.convert_to_tensor(m);
+
+            var lrArray = new float[,,,] {
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 },
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                }
+            };
+            var lrImg = ops.convert_to_tensor(lrArray);
+
+            // build the up-down flipped array
+            var udArray = new float[,,,] {
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 237, 28, 36 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    }
+                },
+                {
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 255, 255, 255 }
+                    },
+                    {
+                        { 255, 255, 255 },
+                        { 255, 255, 255 },
+                        { 237, 28, 36 }
+                    }
+                }
+            };
+            var udImg = ops.convert_to_tensor(udArray);
+
+            var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            var ud2 = tf.reverse(sourceImg, new Axis(1));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");
+
+            var ud = tf.image.flip_up_down(sourceImg);
+            Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");
+
+            // flip left-right
+            var lr = tf.image.flip_left_right(sourceImg);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");
+
+            var lr2 = tf.reverse(sourceImg, 0);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");
+
+            var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs
new file mode 100644
index 000000000..f5a8685be
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs
@@ -0,0 +1,44 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Linq;
+using static Tensorflow.Binding;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest.NumPy
+{
+    [TestClass]
+    public class ShapeTest : EagerModeTestBase
+    {
+        [Ignore]
+        [TestMethod]
+        public unsafe void ShapeGetLastElements()
+        {
+            // Test code taken from _CheckAtLeast3DImage.
+            // The previous _CheckAtLeast3DImage had a bug; it now passes its tests, and the code below is correct.
+            // TODO: the shape["-3:"] indexer is still buggy; re-enable this test once it is fixed.
+
+            var image_shape = new Shape(new[] { 32, 64, 3 });
+            var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 });
+
+            var image_shape_last_three_elements = new Shape(new[] {
+                image_shape.dims[image_shape.dims.Length - 3],
+                image_shape.dims[image_shape.dims.Length - 2],
+                image_shape.dims[image_shape.dims.Length - 1]});
+
+            var image_shape_last_three_elements2 = image_shape["-3:"];
+
+            Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail.");
+
+            var image_shape_last_three_elements_4d = new Shape(new[] {
+                image_shape_4d.dims[image_shape_4d.dims.Length - 3],
+                image_shape_4d.dims[image_shape_4d.dims.Length - 2],
+                image_shape_4d.dims[image_shape_4d.dims.Length - 1]});
+
+            var image_shape_last_three_elements2_4d = image_shape_4d["-3:"];
+
+            Assert.IsTrue(Equal(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail.");
+        }
+    }
+}
\ No newline at end of file
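
ShapeTest pins down the semantics the negative-range indexer should have once it is fixed. For reference (a sketch of the expected behavior, not part of the patch):

    var s = new Shape(new[] { 4, 64, 32, 3 });
    var tail = s["-3:"];   // expected: (64, 32, 3), i.e. the last three dimensions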