diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 0681258e4..8635f6620 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -76,13 +76,13 @@ public IInitializer random_normal_initializer(float mean = 0.0f,
///
///
public IInitializer variance_scaling_initializer(float factor = 1.0f,
- string mode = "FAN_IN",
- bool uniform = false,
+ string mode = "fan_in",
+ string distribution = "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling(
- factor: factor,
+ scale: factor,
mode: mode,
- uniform: uniform,
+ distribution: distribution,
seed: seed,
dtype: dtype);
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
index a0724630c..f34c63d1b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
@@ -6,34 +6,34 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ConvolutionalArgs : AutoSerializeLayerArgs
{
- public int Rank { get; set; } = 2;
+ public int Rank { get; set; }
[JsonProperty("filters")]
public int Filters { get; set; }
public int NumSpatialDims { get; set; } = Unknown;
[JsonProperty("kernel_size")]
- public Shape KernelSize { get; set; } = 5;
+ public Shape KernelSize { get; set; }
///
/// specifying the stride length of the convolution.
///
[JsonProperty("strides")]
- public Shape Strides { get; set; } = (1, 1);
+ public Shape Strides { get; set; }
[JsonProperty("padding")]
- public string Padding { get; set; } = "valid";
+ public string Padding { get; set; }
[JsonProperty("data_format")]
public string DataFormat { get; set; }
[JsonProperty("dilation_rate")]
- public Shape DilationRate { get; set; } = (1, 1);
+ public Shape DilationRate { get; set; }
[JsonProperty("groups")]
- public int Groups { get; set; } = 1;
+ public int Groups { get; set; }
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; }
[JsonProperty("kernel_initializer")]
- public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
+ public IInitializer KernelInitializer { get; set; }
[JsonProperty("bias_initializer")]
- public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
+ public IInitializer BiasInitializer { get; set; }
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }
[JsonProperty("bias_regularizer")]
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs
new file mode 100644
index 000000000..0ff245180
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs
@@ -0,0 +1,68 @@
+using Newtonsoft.Json.Linq;
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations;
+using Tensorflow.Operations.Initializers;
+
+namespace Tensorflow.Keras.Common
+{
+ class InitializerInfo
+ {
+ public string class_name { get; set; }
+ public JObject config { get; set; }
+ }
+ public class CustomizedIinitializerJsonConverter : JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(IInitializer);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ var initializer = value as IInitializer;
+ if(initializer is null)
+ {
+ JToken.FromObject(null).WriteTo(writer);
+ return;
+ }
+ JToken.FromObject(new InitializerInfo()
+ {
+ class_name = initializer.ClassName,
+ config = JObject.FromObject(initializer.Config)
+ }, serializer).WriteTo(writer);
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ var info = serializer.Deserialize(reader);
+ if(info is null)
+ {
+ return null;
+ }
+ return info.class_name switch
+ {
+ "Constant" => new Constant(info.config["value"].ToObject()),
+ "GlorotUniform" => new GlorotUniform(seed: info.config["seed"].ToObject()),
+ "Ones" => new Ones(),
+ "Orthogonal" => new Orthogonal(info.config["gain"].ToObject(), info.config["seed"].ToObject()),
+ "RandomNormal" => new RandomNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(),
+ info.config["seed"].ToObject()),
+ "RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject(),
+ maxval:info.config["maxval"].ToObject(), seed: info.config["seed"].ToObject()),
+ "TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject(), info.config["stddev"].ToObject(),
+ info.config["seed"].ToObject()),
+ "VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject(), info.config["mode"].ToObject(),
+ info.config["distribution"].ToObject(), info.config["seed"].ToObject()),
+ "Zeros" => new Zeros(),
+ _ => throw new ValueError($"The specified initializer {info.class_name} cannot be recognized.")
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
index 722e0a75e..9d4b53a99 100644
--- a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
@@ -60,12 +60,20 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
- var shape_info_from_python = serializer.Deserialize(reader);
- if (shape_info_from_python is null)
+ long?[] dims;
+ try
{
- return null;
+ var shape_info_from_python = serializer.Deserialize(reader);
+ if (shape_info_from_python is null)
+ {
+ return null;
+ }
+ dims = shape_info_from_python.items;
+ }
+ catch(JsonSerializationException)
+ {
+ dims = serializer.Deserialize(reader);
}
- long ?[]dims = shape_info_from_python.items;
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
index def1cb7a0..7cd88cc68 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
@@ -26,12 +26,12 @@ public class GlorotUniform : VarianceScaling
public override IDictionary Config => _config;
public GlorotUniform(float scale = 1.0f,
- string mode = "FAN_AVG",
- bool uniform = true,
+ string mode = "fan_avg",
+ string distribution = "uniform",
int? seed = null,
- TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
+ TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale: scale,
mode: mode,
- uniform: uniform,
+ distribution: distribution,
seed: seed,
dtype: dtype)
{
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
index 9748b1004..ca8348aa6 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
@@ -16,9 +16,11 @@ limitations under the License.
using Newtonsoft.Json;
using System.Collections.Generic;
+using Tensorflow.Keras.Common;
namespace Tensorflow
{
+ [JsonConverter(typeof(CustomizedIinitializerJsonConverter))]
public interface IInitializer
{
[JsonProperty("class_name")]
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
index f104e8e83..37fdd764c 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -28,35 +28,42 @@ public class VarianceScaling : IInitializer
{
protected float _scale;
protected string _mode;
- protected string _distribution;
protected int? _seed;
protected TF_DataType _dtype;
- protected bool _uniform;
+ protected string _distribution;
private readonly Dictionary _config;
public virtual string ClassName => "VarianceScaling";
public virtual IDictionary Config => _config;
- public VarianceScaling(float factor = 2.0f,
- string mode = "FAN_IN",
- bool uniform = false,
+ public VarianceScaling(float scale = 1.0f,
+ string mode = "fan_in",
+ string distribution = "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
if (!dtype.is_floating())
throw new TypeError("Cannot create initializer for non-floating point type.");
- if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
- throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]");
+ if (!new string[] { "fan_in", "fan_out", "fan_avg" }.Contains(mode))
+ throw new TypeError($"Unknown {mode} %s [fan_in, fan_out, fan_avg]");
+ if(distribution == "normal")
+ {
+ distribution = "truncated_normal";
+ }
+ if(!new string[] { "uniform", "truncated_normal", "untruncated_normal" }.Contains(distribution))
+ {
+ throw new ValueError($"Invalid `distribution` argument: {distribution}");
+ }
- if (factor < 0)
+ if (scale <= 0)
throw new ValueError("`scale` must be positive float.");
- _scale = factor;
+ _scale = scale;
_mode = mode;
_seed = seed;
_dtype = dtype;
- _uniform = uniform;
+ _distribution = distribution;
_config = new();
_config["scale"] = _scale;
@@ -72,23 +79,28 @@ public Tensor Apply(InitializerArgs args)
float n = 0;
var (fan_in, fan_out) = _compute_fans(args.Shape);
- if (_mode == "FAN_IN")
- n = fan_in;
- else if (_mode == "FAN_OUT")
- n = fan_out;
- else if (_mode == "FAN_AVG")
- n = (fan_in + fan_out) / 2.0f;
+ var scale = this._scale;
+ if (_mode == "fan_in")
+ scale /= Math.Max(1.0f, fan_in);
+ else if (_mode == "fan_out")
+ scale /= Math.Max(1.0f, fan_out);
+ else
+ scale /= Math.Max(1.0f, (fan_in + fan_out) / 2);
- if (_uniform)
+ if(_distribution == "truncated_normal")
{
- var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
- return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
+ var stddev = Math.Sqrt(scale) / .87962566103423978f;
+ return random_ops.truncated_normal(args.Shape, 0.0f, (float)stddev, args.DType);
+ }
+ else if(_distribution == "untruncated_normal")
+ {
+ var stddev = Math.Sqrt(scale);
+ return random_ops.random_normal(args.Shape, 0.0f, (float)stddev, args.DType);
}
else
{
- var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
- return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
- seed: _seed);
+ var limit = (float)Math.Sqrt(scale * 3.0f);
+ return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs
index ba59b3675..fe67c2b84 100644
--- a/src/TensorFlowNET.Core/Operations/gen_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs
@@ -29543,7 +29543,7 @@ public static (Tensor e, Tensor v) self_adjoint_eig_v2(Tensor input, bool? compu
/// if < 0, scale * features otherwise.
///
/// To be used together with
- /// initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN').
+ /// initializer = tf.variance_scaling_initializer(scale=1.0, mode='fan_in').
/// For correct dropout, use tf.contrib.nn.alpha_dropout.
///
/// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
diff --git a/src/TensorFlowNET.Keras/InitializersApi.cs b/src/TensorFlowNET.Keras/InitializersApi.cs
index 6bade1720..d6dfa51be 100644
--- a/src/TensorFlowNET.Keras/InitializersApi.cs
+++ b/src/TensorFlowNET.Keras/InitializersApi.cs
@@ -27,7 +27,7 @@ public partial class InitializersApi : IInitializersApi
///
public IInitializer HeNormal(int? seed = null)
{
- return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);
+ return new VarianceScaling(scale: 2.0f, mode: "fan_in", seed: seed);
}
public IInitializer Orthogonal(float gain = 1.0f, int? seed = null)
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs
index d62b33a58..3ee61253c 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs
@@ -20,9 +20,46 @@ namespace Tensorflow.Keras.Layers
{
public class Conv1D : Convolutional
{
- public Conv1D(Conv1DArgs args) : base(args)
+ public Conv1D(Conv1DArgs args) : base(InitializeUndefinedArgs(args))
{
}
+
+ private static Conv1DArgs InitializeUndefinedArgs(Conv1DArgs args)
+ {
+ if(args.Rank == 0)
+ {
+ args.Rank = 1;
+ }
+ if(args.Strides is null)
+ {
+ args.Strides = 1;
+ }
+ if (string.IsNullOrEmpty(args.Padding))
+ {
+ args.Padding = "valid";
+ }
+ if (string.IsNullOrEmpty(args.DataFormat))
+ {
+ args.DataFormat = "channels_last";
+ }
+ if(args.DilationRate == 0)
+ {
+ args.DilationRate = 1;
+ }
+ if(args.Groups == 0)
+ {
+ args.Groups = 1;
+ }
+ if(args.KernelInitializer is null)
+ {
+ args.KernelInitializer = tf.glorot_uniform_initializer;
+ }
+ if(args.BiasInitializer is null)
+ {
+ args.BiasInitializer = tf.zeros_initializer;
+ }
+ return args;
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs
index c5c210152..a6963e307 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs
@@ -20,9 +20,42 @@ namespace Tensorflow.Keras.Layers
{
public class Conv2D : Convolutional
{
- public Conv2D(Conv2DArgs args) : base(args)
+ public Conv2D(Conv2DArgs args) : base(InitializeUndefinedArgs(args))
{
}
+
+ private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args)
+ {
+ if(args.Rank == 0)
+ {
+ args.Rank = 2;
+ }
+ if (args.Strides is null)
+ {
+ args.Strides = (1, 1);
+ }
+ if (string.IsNullOrEmpty(args.Padding))
+ {
+ args.Padding = "valid";
+ }
+ if (args.DilationRate == 0)
+ {
+ args.DilationRate = (1, 1);
+ }
+ if (args.Groups == 0)
+ {
+ args.Groups = 1;
+ }
+ if (args.KernelInitializer is null)
+ {
+ args.KernelInitializer = tf.glorot_uniform_initializer;
+ }
+ if (args.BiasInitializer is null)
+ {
+ args.BiasInitializer = tf.zeros_initializer;
+ }
+ return args;
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
index 7b281b28e..de4080b05 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
@@ -24,11 +24,40 @@ namespace Tensorflow.Keras.Layers
{
public class Conv2DTranspose : Conv2D
{
- public Conv2DTranspose(Conv2DArgs args) : base(args)
+ public Conv2DTranspose(Conv2DArgs args) : base(InitializeUndefinedArgs(args))
{
}
+ private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args)
+ {
+ if (args.Strides is null)
+ {
+ args.Strides = (1, 1);
+ }
+ if (string.IsNullOrEmpty(args.Padding))
+ {
+ args.Padding = "valid";
+ }
+ if (args.DilationRate == 0)
+ {
+ args.DilationRate = (1, 1);
+ }
+ if (args.Groups == 0)
+ {
+ args.Groups = 1;
+ }
+ if (args.KernelInitializer is null)
+ {
+ args.KernelInitializer = tf.glorot_uniform_initializer;
+ }
+ if (args.BiasInitializer is null)
+ {
+ args.BiasInitializer = tf.zeros_initializer;
+ }
+ return args;
+ }
+
public override void build(Shape input_shape)
{
if (len(input_shape) != 4)