Skip to content

Commit 12e3f54

Browse files
authored
Merge pull request #1148 from Wanglongzhi2001/master
fix: make the initialization of the layer's name correct
2 parents dffc465 + f6f792a commit 12e3f54

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

src/TensorFlowNET.Keras/Utils/generic_utils.cs

+9-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License.
2929
using Tensorflow.Keras.Layers;
3030
using Tensorflow.Keras.Saving;
3131
using Tensorflow.Train;
32+
using System.Text.RegularExpressions;
3233

3334
namespace Tensorflow.Keras.Utils
3435
{
@@ -126,12 +127,15 @@ public static FunctionalConfig deserialize_model_config(JToken json)
126127

127128
public static string to_snake_case(string name)
128129
{
129-
return string.Concat(name.Select((x, i) =>
130+
string intermediate = Regex.Replace(name, "(.)([A-Z][a-z0-9]+)", "$1_$2");
131+
string insecure = Regex.Replace(intermediate, "([a-z])([A-Z])", "$1_$2").ToLower();
132+
133+
if (insecure[0] != '_')
130134
{
131-
return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ?
132-
"_" + x.ToString() :
133-
x.ToString();
134-
})).ToLower();
135+
return insecure;
136+
}
137+
138+
return "private" + insecure;
135139
}
136140

137141
/// <summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using Tensorflow.Keras.Layers;
3+
using static Tensorflow.Binding;
4+
using static Tensorflow.KerasApi;
5+
6+
namespace Tensorflow.Keras.UnitTest
7+
{
8+
[TestClass]
9+
public class InitLayerNameTest
10+
{
11+
[TestMethod]
12+
public void RNNLayerNameTest()
13+
{
14+
var simpleRnnCell = keras.layers.SimpleRNNCell(1);
15+
Assert.AreEqual("simple_rnn_cell", simpleRnnCell.Name);
16+
var simpleRnn = keras.layers.SimpleRNN(2);
17+
Assert.AreEqual("simple_rnn", simpleRnn.Name);
18+
var lstmCell = keras.layers.LSTMCell(2);
19+
Assert.AreEqual("lstm_cell", lstmCell.Name);
20+
var lstm = keras.layers.LSTM(3);
21+
Assert.AreEqual("lstm", lstm.Name);
22+
}
23+
24+
[TestMethod]
25+
public void ConvLayerNameTest()
26+
{
27+
var conv2d = keras.layers.Conv2D(8, activation: "linear");
28+
Assert.AreEqual("conv2d", conv2d.Name);
29+
var conv2dTranspose = keras.layers.Conv2DTranspose(8);
30+
Assert.AreEqual("conv2d_transpose", conv2dTranspose.Name);
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)