diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index 49fc79251..081c26cb9 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -116,23 +116,13 @@ public DatasetPass load_data(
                 for (var i = 0; i < x_train_array.GetLength(0); i++)
                 {
                     new_x_train_array[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_train_array.GetLength(1); j++)
-                    {
-                        if (x_train_array[i, j] == 0)
-                            break;
-                        new_x_train_array[i, j + 1] = x_train_array[i, j];
-                    }
+                    Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1)); // flat offsets: row i of the source lands one column to the right, after start_char
                 }
                 int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
                 for (var i = 0; i < x_test_array.GetLength(0); i++)
                 {
                     new_x_test_array[i, 0] = (int)start_char;
-                    for (var j = 0; j < x_test_array.GetLength(1); j++)
-                    {
-                        if (x_test_array[i, j] == 0)
-                            break;
-                        new_x_test_array[i, j + 1] = x_test_array[i, j];
-                    }
+                    Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1)); // same one-column shift for the test rows
                 }
                 x_train_array = new_x_train_array;
                 x_test_array = new_x_test_array;
@@ -163,15 +153,19 @@ public DatasetPass load_data(
             {
                 maxlen = max(x_train_array.GetLength(1), x_test_array.GetLength(1));
             }
-            (x_train, labels_train) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
-            (x_test, labels_test) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
-            if (x_train.size == 0 || x_test.size == 0)
+            (x_train_array, labels_train_array) = data_utils._remove_long_seq((int)maxlen, x_train_array, labels_train_array);
+            (x_test_array, labels_test_array) = data_utils._remove_long_seq((int)maxlen, x_test_array, labels_test_array);
+            if (x_train_array.Length == 0 || x_test_array.Length == 0)
                 throw new ValueError("After filtering for sequences shorter than maxlen=" +
                     $"{maxlen}, no sequence was kept. Increase maxlen.");
 
-            var xs = np.concatenate(new[] { x_train, x_test });
-            var labels = np.concatenate(new[] { labels_train, labels_test });
-            var xs_array = (int[,])xs.ToMultiDimArray();
+            int[,] xs_array = new int[x_train_array.GetLength(0) + x_test_array.GetLength(0), (int)maxlen];
+            Array.Copy(x_train_array, xs_array, x_train_array.Length);
+            Array.Copy(x_test_array, 0, xs_array, x_train_array.Length, x_test_array.Length); // count must be the test block's own element count, not the train block's
+
+            long[] labels_array = new long[labels_train_array.Length + labels_test_array.Length];
+            Array.Copy(labels_train_array, labels_array, labels_train_array.Length);
+            Array.Copy(labels_test_array, 0, labels_array, labels_train_array.Length, labels_test_array.Length);
 
             if (num_words == null)
             {
@@ -197,7 +191,7 @@ public DatasetPass load_data(
                             new_xs_array[i, j] = (int)oov_char;
                     }
                 }
-                xs = new NDArray(new_xs_array);
+                xs_array = new_xs_array;
             }
             else
             {
@@ -211,19 +205,19 @@ public DatasetPass load_data(
                             new_xs_array[i, k++] = xs_array[i, j];
                     }
                 }
-                xs = new NDArray(new_xs_array);
+                xs_array = new_xs_array;
             }
 
-            var idx = len(x_train);
-            x_train = xs[$"0:{idx}"];
-            x_test = xs[$"{idx}:"];
-            var y_train = labels[$"0:{idx}"];
-            var y_test = labels[$"{idx}:"];
+            Array.Copy(xs_array, x_train_array, x_train_array.Length);
+            Array.Copy(xs_array, x_train_array.Length, x_test_array, 0, x_test_array.Length); // test rows start right after the train rows; copy exactly the test block's element count
+
+            Array.Copy(labels_array, labels_train_array, labels_train_array.Length);
+            Array.Copy(labels_array, labels_train_array.Length, labels_test_array, 0, labels_test_array.Length);
 
             return new DatasetPass
             {
-                Train = (x_train, y_train),
-                Test = (x_test, y_test)
+                Train = (x_train_array, labels_train_array),
+                Test = (x_test_array, labels_test_array)
             };
         }
 
diff --git a/src/TensorFlowNET.Keras/Utils/data_utils.cs b/src/TensorFlowNET.Keras/Utils/data_utils.cs
index 57ae76695..e6db0ef72 100644
--- a/src/TensorFlowNET.Keras/Utils/data_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/data_utils.cs
@@ -40,7 +40,7 @@ public static string get_file(string fname, string origin,
             return datadir;
         }
 
-        public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArray label)
+        public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] label)
         {
             /*Removes sequences that exceed the maximum length.
 
@@ -56,19 +56,17 @@ public static (NDArray, NDArray) _remove_long_seq(int maxlen, NDArray seq, NDArr
             List<int[]> new_seq = new List<int[]>();
             List<long> new_label = new List<long>();
 
-            var seq_array = (int[,])seq.ToMultiDimArray();
-            var label_array = (long[])label.ToArray();
-            for (var i = 0; i < seq_array.GetLength(0); i++)
+            for (var i = 0; i < seq.GetLength(0); i++)
             {
-                if (maxlen < seq_array.GetLength(1) && seq_array[i,maxlen] != 0)
+                if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
                     continue;
                 int[] sentence = new int[maxlen];
-                for (var j = 0; j < maxlen && j < seq_array.GetLength(1); j++)
+                for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
                 {
-                    sentence[j] = seq_array[i, j];
+                    sentence[j] = seq[i, j];
                 }
                 new_seq.Add(sentence);
-                new_label.Add(label_array[i]);
+                new_label.Add(label[i]);
             }
 
             int[,] new_seq_array = new int[new_seq.Count, maxlen];