From 9fb847991a1e45c0dbf40fd896b36b6d91953a24 Mon Sep 17 00:00:00 2001
From: lingbai-kong
Date: Fri, 22 Sep 2023 18:34:08 +0800
Subject: [PATCH] fix: adjust imdb dataset loader for faster loading speed

---
 src/TensorFlowNET.Keras/Datasets/Imdb.cs    | 29 ++++++++++++---------
 src/TensorFlowNET.Keras/Utils/data_utils.cs |  8 +++---
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index 1c980518..4d6df913 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -112,35 +112,39 @@ public DatasetPass load_data(
 
             if (start_char != null)
             {
-                int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
-                for (var i = 0; i < x_train_array.GetLength(0); i++)
+                var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
+                int[,] new_x_train_array = new int[d1, d2 + 1];
+                for (var i = 0; i < d1; i++)
                 {
                     new_x_train_array[i, 0] = (int)start_char;
-                    Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
+                    Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2);
                 }
-                int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
-                for (var i = 0; i < x_test_array.GetLength(0); i++)
+                (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
+                int[,] new_x_test_array = new int[d1, d2 + 1];
+                for (var i = 0; i < d1; i++)
                 {
                     new_x_test_array[i, 0] = (int)start_char;
-                    Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
+                    Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2);
                 }
                 x_train_array = new_x_train_array;
                 x_test_array = new_x_test_array;
             }
             else if (index_from != 0)
             {
-                for (var i = 0; i < x_train_array.GetLength(0); i++)
+                var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
+                for (var i = 0; i < d1; i++)
                 {
-                    for (var j = 0; j < x_train_array.GetLength(1); j++)
+                    for (var j = 0; j < d2; j++)
                     {
                         if (x_train_array[i, j] == 0)
                             break;
                         x_train_array[i, j] += index_from;
                     }
                 }
-                for (var i = 0; i < x_test_array.GetLength(0); i++)
+                (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
+                for (var i = 0; i < d1; i++)
                 {
-                    for (var j = 0; j < x_test_array.GetLength(1); j++)
+                    for (var j = 0; j < d2; j++)
                     {
                         if (x_test_array[i, j] == 0)
                             break;
@@ -169,9 +173,10 @@ public DatasetPass load_data(
 
             if (num_words == null)
             {
+                var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1));
                 num_words = 0;
-                for (var i = 0; i < xs_array.GetLength(0); i++)
-                    for (var j = 0; j < xs_array.GetLength(1); j++)
+                for (var i = 0; i < d1; i++)
+                    for (var j = 0; j < d2; j++)
                         num_words = max((int)num_words, (int)xs_array[i, j]);
             }
 
diff --git a/src/TensorFlowNET.Keras/Utils/data_utils.cs b/src/TensorFlowNET.Keras/Utils/data_utils.cs
index e6db0ef7..b0bc1554 100644
--- a/src/TensorFlowNET.Keras/Utils/data_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/data_utils.cs
@@ -53,15 +53,17 @@ public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] l
                 new_seq, new_label: shortened lists for `seq` and `label`.
 
             */
+            var nRow = seq.GetLength(0);
+            var nCol = seq.GetLength(1);
             List<int[]> new_seq = new List<int[]>();
             List<long> new_label = new List<long>();
 
-            for (var i = 0; i < seq.GetLength(0); i++)
+            for (var i = 0; i < nRow; i++)
             {
-                if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
+                if (maxlen < nCol && seq[i, maxlen] != 0)
                     continue;
                 int[] sentence = new int[maxlen];
-                for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
+                for (var j = 0; j < maxlen && j < nCol; j++)
                 {
                     sentence[j] = seq[i, j];
                 }
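
Note below the patch (not part of the diff): every hunk above applies the same micro-optimization of caching the results of int[,].GetLength(0)/GetLength(1) in locals instead of re-evaluating them on each loop iteration. The following is a minimal standalone C# sketch of that pattern; the names (data, rows, cols) and the array size are hypothetical, chosen only to mirror the d1/d2 and nRow/nCol locals the patch introduces.

using System;

class GetLengthHoistingSketch
{
    static void Main()
    {
        // Hypothetical matrix shaped roughly like a padded IMDB batch.
        int[,] data = new int[25000, 400];

        // Before: GetLength(...) is called on every iteration of both loops.
        long sumBefore = 0;
        for (var i = 0; i < data.GetLength(0); i++)
            for (var j = 0; j < data.GetLength(1); j++)
                sumBefore += data[i, j];

        // After: read the dimensions once and cache them in locals,
        // mirroring the d1/d2 and nRow/nCol variables in the patch.
        var (rows, cols) = (data.GetLength(0), data.GetLength(1));
        long sumAfter = 0;
        for (var i = 0; i < rows; i++)
            for (var j = 0; j < cols; j++)
                sumAfter += data[i, j];

        // Identical results: the rewrite only changes how often GetLength runs.
        Console.WriteLine(sumBefore == sumAfter);
    }
}

Whether the runtime hoists these GetLength calls on its own varies for rectangular arrays, so caching them explicitly, as the patch does, is the predictable option and keeps behavior unchanged.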