Skip to content

Commit 5b33938

Browse files
authored
Merge pull request #1031 from AsakusaRinne/master
Fix the error when loading Conv1D layer with initializer.
2 parents 682f52f + c1a14c7 commit 5b33938

File tree

12 files changed

+236
-47
lines changed

12 files changed

+236
-47
lines changed

src/TensorFlowNET.Core/APIs/tf.init.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ public IInitializer random_normal_initializer(float mean = 0.0f,
7676
/// <param name="dtype"></param>
7777
/// <returns></returns>
7878
public IInitializer variance_scaling_initializer(float factor = 1.0f,
79-
string mode = "FAN_IN",
80-
bool uniform = false,
79+
string mode = "fan_in",
80+
string distribution = "truncated_normal",
8181
int? seed = null,
8282
TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling(
83-
factor: factor,
83+
scale: factor,
8484
mode: mode,
85-
uniform: uniform,
85+
distribution: distribution,
8686
seed: seed,
8787
dtype: dtype);
8888

src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,34 @@ namespace Tensorflow.Keras.ArgsDefinition
66
{
77
public class ConvolutionalArgs : AutoSerializeLayerArgs
88
{
9-
public int Rank { get; set; } = 2;
9+
public int Rank { get; set; }
1010
[JsonProperty("filters")]
1111
public int Filters { get; set; }
1212
public int NumSpatialDims { get; set; } = Unknown;
1313
[JsonProperty("kernel_size")]
14-
public Shape KernelSize { get; set; } = 5;
14+
public Shape KernelSize { get; set; }
1515

1616
/// <summary>
1717
/// specifying the stride length of the convolution.
1818
/// </summary>
1919
[JsonProperty("strides")]
20-
public Shape Strides { get; set; } = (1, 1);
20+
public Shape Strides { get; set; }
2121
[JsonProperty("padding")]
22-
public string Padding { get; set; } = "valid";
22+
public string Padding { get; set; }
2323
[JsonProperty("data_format")]
2424
public string DataFormat { get; set; }
2525
[JsonProperty("dilation_rate")]
26-
public Shape DilationRate { get; set; } = (1, 1);
26+
public Shape DilationRate { get; set; }
2727
[JsonProperty("groups")]
28-
public int Groups { get; set; } = 1;
28+
public int Groups { get; set; }
2929
[JsonProperty("activation")]
3030
public Activation Activation { get; set; }
3131
[JsonProperty("use_bias")]
3232
public bool UseBias { get; set; }
3333
[JsonProperty("kernel_initializer")]
34-
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
34+
public IInitializer KernelInitializer { get; set; }
3535
[JsonProperty("bias_initializer")]
36-
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
36+
public IInitializer BiasInitializer { get; set; }
3737
[JsonProperty("kernel_regularizer")]
3838
public IRegularizer KernelRegularizer { get; set; }
3939
[JsonProperty("bias_regularizer")]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Tensorflow.Operations;
7+
using Tensorflow.Operations.Initializers;
8+
9+
namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// Intermediate JSON shape of a serialized initializer: an object with a
    /// <c>class_name</c> string and a <c>config</c> object.
    /// </summary>
    class InitializerInfo
    {
        public string class_name { get; set; }
        public JObject config { get; set; }
    }

    /// <summary>
    /// Converts <see cref="IInitializer"/> values to and from the
    /// <c>{ "class_name": ..., "config": { ... } }</c> JSON representation.
    /// </summary>
    public class CustomizedIinitializerJsonConverter : JsonConverter
    {
        /// <summary>
        /// This converter only handles values declared as <see cref="IInitializer"/>.
        /// </summary>
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(IInitializer);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        /// <summary>
        /// Writes the initializer as an object holding its class name and config.
        /// A null (or non-initializer) value is written as JSON <c>null</c>.
        /// </summary>
        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            var initializer = value as IInitializer;
            if (initializer is null)
            {
                // JToken.FromObject throws on a null argument, so emit the
                // null token through the writer directly.
                writer.WriteNull();
                return;
            }

            // JObject.FromObject also rejects null; fall back to an empty config object.
            var config = initializer.Config;
            JToken.FromObject(new InitializerInfo()
            {
                class_name = initializer.ClassName,
                config = config is null ? new JObject() : JObject.FromObject(config)
            }, serializer).WriteTo(writer);
        }

        /// <summary>
        /// Reads a <c>class_name</c>/<c>config</c> object and builds the matching initializer.
        /// </summary>
        /// <returns>The initializer, or null when the JSON value is null.</returns>
        /// <exception cref="ValueError">
        /// The class name is not recognized, or a required config field is missing.
        /// </exception>
        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            var info = serializer.Deserialize<InitializerInfo>(reader);
            if (info is null)
            {
                return null;
            }

            return info.class_name switch
            {
                "Constant" => new Constant<float>(Required<float>(info, "value")),
                "GlorotUniform" => new GlorotUniform(seed: Seed(info)),
                "Ones" => new Ones(),
                "Orthogonal" => new Orthogonal(Required<float>(info, "gain"), Seed(info)),
                "RandomNormal" => new RandomNormal(Required<float>(info, "mean"), Required<float>(info, "stddev"),
                    Seed(info)),
                "RandomUniform" => new RandomUniform(minval: Required<float>(info, "minval"),
                    maxval: Required<float>(info, "maxval"), seed: Seed(info)),
                "TruncatedNormal" => new TruncatedNormal(Required<float>(info, "mean"), Required<float>(info, "stddev"),
                    Seed(info)),
                "VarianceScaling" => new VarianceScaling(Required<float>(info, "scale"), Required<string>(info, "mode"),
                    Required<string>(info, "distribution"), Seed(info)),
                "Zeros" => new Zeros(),
                _ => throw new ValueError($"The specified initializer {info.class_name} cannot be recognized.")
            };
        }

        // Reads a mandatory config field, failing with a descriptive error
        // instead of a NullReferenceException when the field (or the whole
        // config object) is absent.
        private static T Required<T>(InitializerInfo info, string key)
        {
            var token = info.config?[key];
            if (token is null)
            {
                throw new ValueError($"The config of initializer {info.class_name} is missing the required field `{key}`.");
            }
            return token.ToObject<T>();
        }

        // Reads the seed; an absent field is treated like an explicit JSON null.
        private static int? Seed(InitializerInfo info)
        {
            return info.config?["seed"]?.ToObject<int?>();
        }
    }
}

src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs

+12-4
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,20 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
6060

6161
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
6262
{
63-
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
64-
if (shape_info_from_python is null)
63+
long?[] dims;
64+
try
6565
{
66-
return null;
66+
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
67+
if (shape_info_from_python is null)
68+
{
69+
return null;
70+
}
71+
dims = shape_info_from_python.items;
72+
}
73+
catch(JsonSerializationException)
74+
{
75+
dims = serializer.Deserialize<long?[]>(reader);
6776
}
68-
long ?[]dims = shape_info_from_python.items;
6977
long[] convertedDims = new long[dims.Length];
7078
for(int i = 0; i < dims.Length; i++)
7179
{

src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ public class GlorotUniform : VarianceScaling
2626
public override IDictionary<string, object> Config => _config;
2727

2828
public GlorotUniform(float scale = 1.0f,
29-
string mode = "FAN_AVG",
30-
bool uniform = true,
29+
string mode = "fan_avg",
30+
string distribution = "uniform",
3131
int? seed = null,
32-
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
32+
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale: scale,
3333
mode: mode,
34-
uniform: uniform,
34+
distribution: distribution,
3535
seed: seed,
3636
dtype: dtype)
3737
{

src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs

+2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ limitations under the License.
1616

1717
using Newtonsoft.Json;
1818
using System.Collections.Generic;
19+
using Tensorflow.Keras.Common;
1920

2021
namespace Tensorflow
2122
{
23+
[JsonConverter(typeof(CustomizedIinitializerJsonConverter))]
2224
public interface IInitializer
2325
{
2426
[JsonProperty("class_name")]

src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

+34-22
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,42 @@ public class VarianceScaling : IInitializer
2828
{
2929
protected float _scale;
3030
protected string _mode;
31-
protected string _distribution;
3231
protected int? _seed;
3332
protected TF_DataType _dtype;
34-
protected bool _uniform;
33+
protected string _distribution;
3534
private readonly Dictionary<string, object> _config;
3635

3736
public virtual string ClassName => "VarianceScaling";
3837

3938
public virtual IDictionary<string, object> Config => _config;
4039

41-
public VarianceScaling(float factor = 2.0f,
42-
string mode = "FAN_IN",
43-
bool uniform = false,
40+
public VarianceScaling(float scale = 1.0f,
41+
string mode = "fan_in",
42+
string distribution = "truncated_normal",
4443
int? seed = null,
4544
TF_DataType dtype = TF_DataType.TF_FLOAT)
4645
{
4746
if (!dtype.is_floating())
4847
throw new TypeError("Cannot create initializer for non-floating point type.");
49-
if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
50-
throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]");
48+
if (!new string[] { "fan_in", "fan_out", "fan_avg" }.Contains(mode))
49+
throw new TypeError($"Unknown {mode} %s [fan_in, fan_out, fan_avg]");
50+
if(distribution == "normal")
51+
{
52+
distribution = "truncated_normal";
53+
}
54+
if(!new string[] { "uniform", "truncated_normal", "untruncated_normal" }.Contains(distribution))
55+
{
56+
throw new ValueError($"Invalid `distribution` argument: {distribution}");
57+
}
5158

52-
if (factor < 0)
59+
if (scale <= 0)
5360
throw new ValueError("`scale` must be positive float.");
5461

55-
_scale = factor;
62+
_scale = scale;
5663
_mode = mode;
5764
_seed = seed;
5865
_dtype = dtype;
59-
_uniform = uniform;
66+
_distribution = distribution;
6067

6168
_config = new();
6269
_config["scale"] = _scale;
@@ -72,23 +79,28 @@ public Tensor Apply(InitializerArgs args)
7279

7380
float n = 0;
7481
var (fan_in, fan_out) = _compute_fans(args.Shape);
75-
if (_mode == "FAN_IN")
76-
n = fan_in;
77-
else if (_mode == "FAN_OUT")
78-
n = fan_out;
79-
else if (_mode == "FAN_AVG")
80-
n = (fan_in + fan_out) / 2.0f;
82+
var scale = this._scale;
83+
if (_mode == "fan_in")
84+
scale /= Math.Max(1.0f, fan_in);
85+
else if (_mode == "fan_out")
86+
scale /= Math.Max(1.0f, fan_out);
87+
else
88+
scale /= Math.Max(1.0f, (fan_in + fan_out) / 2);
8189

82-
if (_uniform)
90+
if(_distribution == "truncated_normal")
8391
{
84-
var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
85-
return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
92+
var stddev = Math.Sqrt(scale) / .87962566103423978f;
93+
return random_ops.truncated_normal(args.Shape, 0.0f, (float)stddev, args.DType);
94+
}
95+
else if(_distribution == "untruncated_normal")
96+
{
97+
var stddev = Math.Sqrt(scale);
98+
return random_ops.random_normal(args.Shape, 0.0f, (float)stddev, args.DType);
8699
}
87100
else
88101
{
89-
var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
90-
return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
91-
seed: _seed);
102+
var limit = (float)Math.Sqrt(scale * 3.0f);
103+
return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
92104
}
93105
}
94106

src/TensorFlowNET.Core/Operations/gen_ops.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29543,7 +29543,7 @@ public static (Tensor e, Tensor v) self_adjoint_eig_v2(Tensor input, bool? compu
2954329543
/// if &amp;lt; 0, <c>scale * features</c> otherwise.
2954429544
///
2954529545
/// To be used together with
29546-
/// <c>initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')</c>.
29546+
/// <c>initializer = tf.variance_scaling_initializer(scale=1.0, mode='fan_in')</c>.
2954729547
/// For correct dropout, use <c>tf.contrib.nn.alpha_dropout</c>.
2954829548
///
2954929549
/// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)

src/TensorFlowNET.Keras/InitializersApi.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public partial class InitializersApi : IInitializersApi
2727
/// <returns></returns>
2828
public IInitializer HeNormal(int? seed = null)
2929
{
30-
return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);
30+
return new VarianceScaling(scale: 2.0f, mode: "fan_in", seed: seed);
3131
}
3232

3333
public IInitializer Orthogonal(float gain = 1.0f, int? seed = null)

src/TensorFlowNET.Keras/Layers/Convolution/Conv1D.cs

+38-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,46 @@ namespace Tensorflow.Keras.Layers
2020
{
2121
public class Conv1D : Convolutional
2222
{
23-
public Conv1D(Conv1DArgs args) : base(args)
23+
public Conv1D(Conv1DArgs args) : base(InitializeUndefinedArgs(args))
2424
{
2525

2626
}
27+
28+
/// <summary>
/// Assigns the 1-D convolution defaults to every argument that is still
/// unset (0, null or empty) and returns the same <paramref name="args"/> instance.
/// </summary>
private static Conv1DArgs InitializeUndefinedArgs(Conv1DArgs args)
{
    // Rank left at 0 means it was never supplied.
    if (args.Rank == 0)
    {
        args.Rank = 1;
    }

    args.Strides ??= 1;

    if (string.IsNullOrEmpty(args.Padding))
    {
        args.Padding = "valid";
    }

    if (string.IsNullOrEmpty(args.DataFormat))
    {
        args.DataFormat = "channels_last";
    }

    if (args.DilationRate == 0)
    {
        args.DilationRate = 1;
    }

    if (args.Groups == 0)
    {
        args.Groups = 1;
    }

    args.KernelInitializer ??= tf.glorot_uniform_initializer;
    args.BiasInitializer ??= tf.zeros_initializer;

    return args;
}
2764
}
2865
}

src/TensorFlowNET.Keras/Layers/Convolution/Conv2D.cs

+34-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,42 @@ namespace Tensorflow.Keras.Layers
2020
{
2121
public class Conv2D : Convolutional
2222
{
23-
public Conv2D(Conv2DArgs args) : base(args)
23+
public Conv2D(Conv2DArgs args) : base(InitializeUndefinedArgs(args))
2424
{
2525

2626
}
27+
28+
/// <summary>
/// Assigns the 2-D convolution defaults to every argument that is still
/// unset (0, null or empty) and returns the same <paramref name="args"/> instance.
/// </summary>
private static Conv2DArgs InitializeUndefinedArgs(Conv2DArgs args)
{
    // Rank left at 0 means it was never supplied.
    if (args.Rank == 0)
    {
        args.Rank = 2;
    }

    args.Strides ??= (1, 1);

    if (string.IsNullOrEmpty(args.Padding))
    {
        args.Padding = "valid";
    }

    if (args.DilationRate == 0)
    {
        args.DilationRate = (1, 1);
    }

    if (args.Groups == 0)
    {
        args.Groups = 1;
    }

    args.KernelInitializer ??= tf.glorot_uniform_initializer;
    args.BiasInitializer ??= tf.zeros_initializer;

    return args;
}
2760
}
2861
}

0 commit comments

Comments
 (0)