Skip to content

Fix the error when loading Conv1D layer with initialzier. #1031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/TensorFlowNET.Core/APIs/tf.init.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ public IInitializer random_normal_initializer(float mean = 0.0f,
/// <param name="dtype"></param>
/// <returns></returns>
public IInitializer variance_scaling_initializer(float factor = 1.0f,
string mode = "FAN_IN",
bool uniform = false,
string mode = "fan_in",
string distribution = "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling(
factor: factor,
scale: factor,
mode: mode,
uniform: uniform,
distribution: distribution,
seed: seed,
dtype: dtype);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,34 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ConvolutionalArgs : AutoSerializeLayerArgs
{
public int Rank { get; set; } = 2;
public int Rank { get; set; }
[JsonProperty("filters")]
public int Filters { get; set; }
public int NumSpatialDims { get; set; } = Unknown;
[JsonProperty("kernel_size")]
public Shape KernelSize { get; set; } = 5;
public Shape KernelSize { get; set; }

/// <summary>
/// specifying the stride length of the convolution.
/// </summary>
[JsonProperty("strides")]
public Shape Strides { get; set; } = (1, 1);
public Shape Strides { get; set; }
[JsonProperty("padding")]
public string Padding { get; set; } = "valid";
public string Padding { get; set; }
[JsonProperty("data_format")]
public string DataFormat { get; set; }
[JsonProperty("dilation_rate")]
public Shape DilationRate { get; set; } = (1, 1);
public Shape DilationRate { get; set; }
[JsonProperty("groups")]
public int Groups { get; set; } = 1;
public int Groups { get; set; }
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; }
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
public IInitializer KernelInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
public IInitializer BiasInitializer { get; set; }
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }
[JsonProperty("bias_regularizer")]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.Initializers;

namespace Tensorflow.Keras.Common
{
class InitializerInfo
{
public string class_name { get; set; }
public JObject config { get; set; }
}
public class CustomizedIinitializerJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(IInitializer);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var initializer = value as IInitializer;
if(initializer is null)
{
JToken.FromObject(null).WriteTo(writer);
return;
}
JToken.FromObject(new InitializerInfo()
{
class_name = initializer.ClassName,
config = JObject.FromObject(initializer.Config)
}, serializer).WriteTo(writer);
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var info = serializer.Deserialize<InitializerInfo>(reader);
if(info is null)
{
return null;
}
return info.class_name switch
{
"Constant" => new Constant<float>(info.config["value"].ToObject<float>()),
"GlorotUniform" => new GlorotUniform(seed: info.config["seed"].ToObject<int?>()),
"Ones" => new Ones(),
"Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()),
"RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(),
maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
"TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
info.config["seed"].ToObject<int?>()),
"VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(),
info.config["distribution"].ToObject<string>(), info.config["seed"].ToObject<int?>()),
"Zeros" => new Zeros(),
_ => throw new ValueError($"The specified initializer {info.class_name} cannot be recognized.")
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,20 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
long?[] dims;
try
{
return null;
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
if (shape_info_from_python is null)
{
return null;
}
dims = shape_info_from_python.items;
}
catch(JsonSerializationException)
{
dims = serializer.Deserialize<long?[]>(reader);
}
long ?[]dims = shape_info_from_python.items;
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ public class GlorotUniform : VarianceScaling
public override IDictionary<string, object> Config => _config;

public GlorotUniform(float scale = 1.0f,
string mode = "FAN_AVG",
bool uniform = true,
string mode = "fan_avg",
string distribution = "uniform",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale: scale,
mode: mode,
uniform: uniform,
distribution: distribution,
seed: seed,
dtype: dtype)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ limitations under the License.

using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Common;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedIinitializerJsonConverter))]
public interface IInitializer
{
[JsonProperty("class_name")]
Expand Down
56 changes: 34 additions & 22 deletions src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,42 @@ public class VarianceScaling : IInitializer
{
protected float _scale;
protected string _mode;
protected string _distribution;
protected int? _seed;
protected TF_DataType _dtype;
protected bool _uniform;
protected string _distribution;
private readonly Dictionary<string, object> _config;

public virtual string ClassName => "VarianceScaling";

public virtual IDictionary<string, object> Config => _config;

public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN",
bool uniform = false,
public VarianceScaling(float scale = 1.0f,
string mode = "fan_in",
string distribution = "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
if (!dtype.is_floating())
throw new TypeError("Cannot create initializer for non-floating point type.");
if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]");
if (!new string[] { "fan_in", "fan_out", "fan_avg" }.Contains(mode))
throw new TypeError($"Unknown {mode} %s [fan_in, fan_out, fan_avg]");
if(distribution == "normal")
{
distribution = "truncated_normal";
}
if(!new string[] { "uniform", "truncated_normal", "untruncated_normal" }.Contains(distribution))
{
throw new ValueError($"Invalid `distribution` argument: {distribution}");
}

if (factor < 0)
if (scale <= 0)
throw new ValueError("`scale` must be positive float.");

_scale = factor;
_scale = scale;
_mode = mode;
_seed = seed;
_dtype = dtype;
_uniform = uniform;
_distribution = distribution;

_config = new();
_config["scale"] = _scale;
Expand All @@ -72,23 +79,28 @@ public Tensor Apply(InitializerArgs args)

float n = 0;
var (fan_in, fan_out) = _compute_fans(args.Shape);
if (_mode == "FAN_IN")
n = fan_in;
else if (_mode == "FAN_OUT")
n = fan_out;
else if (_mode == "FAN_AVG")
n = (fan_in + fan_out) / 2.0f;
var scale = this._scale;
if (_mode == "fan_in")
scale /= Math.Max(1.0f, fan_in);
else if (_mode == "fan_out")
scale /= Math.Max(1.0f, fan_out);
else
scale /= Math.Max(1.0f, (fan_in + fan_out) / 2);

if (_uniform)
if(_distribution == "truncated_normal")
{
var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
var stddev = Math.Sqrt(scale) / .87962566103423978f;
return random_ops.truncated_normal(args.Shape, 0.0f, (float)stddev, args.DType);
}
else if(_distribution == "untruncated_normal")
{
var stddev = Math.Sqrt(scale);
return random_ops.random_normal(args.Shape, 0.0f, (float)stddev, args.DType);
}
else
{
var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
seed: _seed);
var limit = (float)Math.Sqrt(scale * 3.0f);
return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/gen_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29543,7 +29543,7 @@ public static (Tensor e, Tensor v) self_adjoint_eig_v2(Tensor input, bool? compu
/// if &amp;lt; 0, <c>scale * features</c> otherwise.
///
/// To be used together with
/// <c>initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')</c>.
/// <c>initializer = tf.variance_scaling_initializer(scale=1.0, mode='fan_in')</c>.
/// For correct dropout, use <c>tf.contrib.nn.alpha_dropout</c>.
///
/// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/InitializersApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public partial class InitializersApi : IInitializersApi
/// <returns></returns>
public IInitializer HeNormal(int? seed = null)
{
return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);
return new VarianceScaling(scale: 2.0f, mode: "fan_in", seed: seed);
}

public IInitializer Orthogonal(float gain = 1.0f, int? seed = null)
Expand Down
39 changes: 38 additions & 1 deletion src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,46 @@ namespace Tensorflow.Keras.Layers
{
public class Conv1D : Convolutional
{
public Conv1D(Conv1DArgs args) : base(args)
public Conv1D(Conv1DArgs args) : base(InitializeUndefinedArgs(args))
{

}

private static Conv1DArgs InitializeUndefinedArgs(Conv1DArgs args)
{
if(args.Rank == 0)
{
args.Rank = 1;
}
if(args.Strides is null)
{
args.Strides = 1;
}
if (string.IsNullOrEmpty(args.Padding))
{
args.Padding = "valid";
}
if (string.IsNullOrEmpty(args.DataFormat))
{
args.DataFormat = "channels_last";
}
if(args.DilationRate == 0)
{
args.DilationRate = 1;
}
if(args.Groups == 0)
{
args.Groups = 1;
}
if(args.KernelInitializer is null)
{
args.KernelInitializer = tf.glorot_uniform_initializer;
}
if(args.BiasInitializer is null)
{
args.BiasInitializer = tf.zeros_initializer;
}
return args;
}
}
}
35 changes: 34 additions & 1 deletion src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,42 @@ namespace Tensorflow.Keras.Layers
{
public class Conv2D : Convolutional
{
public Conv2D(Conv2DArgs args) : base(args)
public Conv2D(Conv2DArgs args) : base(InitializeUndefinedArgs(args))
{

}

private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args)
{
if(args.Rank == 0)
{
args.Rank = 2;
}
if (args.Strides is null)
{
args.Strides = (1, 1);
}
if (string.IsNullOrEmpty(args.Padding))
{
args.Padding = "valid";
}
if (args.DilationRate == 0)
{
args.DilationRate = (1, 1);
}
if (args.Groups == 0)
{
args.Groups = 1;
}
if (args.KernelInitializer is null)
{
args.KernelInitializer = tf.glorot_uniform_initializer;
}
if (args.BiasInitializer is null)
{
args.BiasInitializer = tf.zeros_initializer;
}
return args;
}
}
}
Loading