Skip to content

Commit 5e60a13

Browse files
authored
Merge pull request #1154 from Beacontownfc/mybranch3
Fix: model.load_weights
2 parents fa2d2dc + 8b17b14 commit 5e60a13

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+9-7
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,8 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
133133
long g = H5G.open(f, name);
134134
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
135135
foreach (var i_ in weight_names)
136-
{
137-
var vm = Regex.Replace(i_, "/", "$");
138-
vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1);
139-
(success, Array result) = Hdf5.ReadDataset<float>(g, vm);
136+
{
137+
(success, Array result) = Hdf5.ReadDataset<float>(g, i_);
140138
if (success)
141139
weight_values.Add(np.array(result));
142140
}
@@ -196,9 +194,13 @@ public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
196194
var tensor = val.AsTensor();
197195
if (name.IndexOf("/") > 1)
198196
{
199-
var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
200-
var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$");
201-
WriteDataset(crDataGroup, _name, tensor);
197+
var crDataGroup = g;
198+
string[] name_split = name.Split('/');
199+
for(int i = 0; i < name_split.Length - 1; i++)
200+
{
201+
crDataGroup = Hdf5.CreateOrOpenGroup(crDataGroup, Hdf5Utils.NormalizedName(name_split[i]));
202+
}
203+
WriteDataset(crDataGroup, name_split[name_split.Length - 1], tensor);
202204
Hdf5.CloseGroup(crDataGroup);
203205
}
204206
else

0 commit comments

Comments
 (0)