From 992bf55dab0273de568e8347d29fdc19e3ad4aa0 Mon Sep 17 00:00:00 2001 From: Beacontownfc <19636977267@qq.com> Date: Sat, 8 Jul 2023 02:39:06 +0000 Subject: [PATCH] fix load_weights --- src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index b04391be9..8ac9fddf6 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -7,6 +7,8 @@ using static Tensorflow.Binding; using static Tensorflow.KerasApi; using System.Linq; +using System.Text.RegularExpressions; + namespace Tensorflow.Keras.Saving { public class hdf5_format @@ -132,7 +134,9 @@ public static void load_weights_from_hdf5_group(long f, List layers) var weight_names = load_attributes_from_hdf5_group(g, "weight_names"); foreach (var i_ in weight_names) { - (success, Array result) = Hdf5.ReadDataset(g, i_); + var vm = Regex.Replace(i_, "/", "$"); + vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1); + (success, Array result) = Hdf5.ReadDataset(g, vm); if (success) weight_values.Add(np.array(result)); } @@ -193,7 +197,8 @@ public static void save_weights_to_hdf5_group(long f, List layers) if (name.IndexOf("/") > 1) { var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0])); - WriteDataset(crDataGroup, name.Split('/')[1], tensor); + var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$"); + WriteDataset(crDataGroup, _name, tensor); Hdf5.CloseGroup(crDataGroup); } else