From 222638b10de5781a11697270f9be24cd453c9c3f Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 14:58:25 +0700 Subject: [PATCH 01/10] Add MultiImageFolder dataset Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 44 +++++++++++++++ torchvision/datasets/folder.py | 100 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/test/test_datasets.py b/test/test_datasets.py index f4ef4721370..3bc7f4e6a3d 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -81,6 +81,50 @@ def test_imagefolder(self): outputs = sorted([dataset[i] for i in range(len(dataset))]) self.assertEqual(imgs, outputs) + def test_multiimagefolder(self): + with get_tmp_dir() as root: + # make fake dataset + os.makedirs(os.path.join(root, 'a')) + a = [] + for filename in ('a1.png', 'a2.png', 'a3.png'): + result = Image.fromarray((np.random.rand(20, 20) * 255).astype(numpy.uint8)) + result.save(os.path.join(root, 'a', filename)) + a.append(os.path.join(root, 'a', filename)) + os.makedirs(os.path.join(root, 'b')) + b = [] + for filename in ('b1.png', 'b2.png', 'b3.png'): + result = Image.fromarray((np.random.rand(20, 20) * 128).astype(numpy.uint8)) + result.save(os.path.join(root, 'b', filename)) + b.append(os.path.join(root, 'b', filename)) + + true_samples = list(zip(a, b)) + + directories = [os.path.join(root, 'a'), os.path.join(root, 'b'), ] + dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x) + + # test if all images were detected correctly and in the proper order + self.assertEqual(len(true_samples), len(dataset.samples)) + for i, j in zip(true_samples, dataset.samples): + self.assertEqual(i, j) + + # test if the datasets outputs all images correctly + for i in range(len(dataset)): + self.assertEqual(true_samples[i], dataset[i]) + + # redo all tests with specified valid image files + dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, + 
is_valid_file=lambda x: '3' in x) + true_samples = [['a3.png', 'b3.png']] + + # test if all images were detected correctly and in the proper order + self.assertEqual(len(true_samples), len(dataset.samples)) + for i, j in zip(true_samples, dataset.samples): + self.assertEqual(i, j) + + # test if the datasets outputs all images correctly + for i in range(len(dataset)): + self.assertEqual(true_samples[i], dataset[i]) + @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract): num_examples = 30 diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index f0546daa93d..7c9c53cc8db 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -5,6 +5,7 @@ import os import os.path import sys +import torch.utils.data as data def has_file_allowed_extension(filename, extensions): @@ -148,6 +149,7 @@ def __len__(self): IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') +REPR_INDENT = 4 def pil_loader(path): @@ -208,3 +210,101 @@ def __init__(self, root, transform=None, target_transform=None, target_transform=target_transform, is_valid_file=is_valid_file) self.imgs = self.samples + + +def make_dataset_noclass(dir, extensions=None, is_valid_file=None): + images = [] + dir = os.path.expanduser(dir) + if not ((extensions is None) ^ (is_valid_file is None)): + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x): + return has_file_allowed_extension(x, extensions) + if not os.path.isdir(dir): + raise ValueError(dir+" is not a directory") + for root, _, fnames in sorted(os.walk(dir)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + images.append(path) + return images + + +class MultiImageFolder(data.Dataset): + """A dataset to load multiple images from separate directories at each iteration. 
+ The directories must be arranged in this way: :: + + dir1/123.png + dir1/456.png + dir1/789.png + dir1/123.png + + dir2/abc.png + dir2/def.png + dir2/ghi.png + dir2/jkl.png + + Args: + directories (list): List of directories where the images are. They must all contain the same number of images. + transforms (list of callable, optional): A list of function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + loader (callable, optional): A function to load an image given its path. + is_valid_file (callable, optional): A function that takes path of an Image file + and check if the file is a valid_file (used to check of corrupt files) + + Attributes: + samples (list): List of image path lists [image path 1, image path 2, ...] len[samples][x] = len(directories) + """ + + def __init__(self, directories, transforms=None, loader=default_loader, is_valid_file=None): + self.directories = [os.path.expanduser(a) for a in directories] + self.transforms = transforms + if self.transforms is not None: + if len(self.directories) != len(self.transforms): + raise ValueError("There must be exactly one transform per directory or no transform at all.") + self.loader = loader + self.extensions = IMG_EXTENSIONS if is_valid_file is None else None + + sampleslist = [make_dataset_noclass(self.directories[i], self.extensions, is_valid_file) for i in range(len(self.directories))] + for list1 in sampleslist: + if len(list1) != len(sampleslist[0]): + raise ValueError("All directories must contain the same number of images.") + if len(list1) == 0: + raise (RuntimeError("At least one of the directories does not contain any valid file.\n" + "Supported extensions are: " + ",".join(extensions))) + self.samples = list(zip(*sampleslist)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (img1, img2, ...) where each image is taken from each of the + initialization directories in the same order. 
+ """ + paths = self.samples[index] + sample = [self.loader(paths[i]) for i in range(len(paths))] + if self.transforms is not None: + sample = [self.transforms[i](sample[i]) for i in range(len(sample))] + return sample + + def __len__(self): + return len(self.samples) + + def __repr__(self): + head = "Dataset " + self.__class__.__name__ + body = ["Number of datapoints: {}".format(self.__len__())] + if self.directories is not None: + for directory in self.directories: + body.append("Directory location: {}".format(directory)) + body += self.extra_repr().splitlines() + if self.transforms is not None: + for transform in self.transforms: + body += [repr(transform)] + body += self.extra_repr().splitlines() + lines = [head] + [" " * REPR_INDENT + line for line in body] + return '\n'.join(lines) + + def extra_repr(self): + return "" From 11009dd5c48e448c8de1256d953d4dbde2d5fc0c Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 15:15:07 +0700 Subject: [PATCH 02/10] Fix test bug with numpy np Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 3bc7f4e6a3d..30d3963df34 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -87,13 +87,13 @@ def test_multiimagefolder(self): os.makedirs(os.path.join(root, 'a')) a = [] for filename in ('a1.png', 'a2.png', 'a3.png'): - result = Image.fromarray((np.random.rand(20, 20) * 255).astype(numpy.uint8)) + result = Image.fromarray((np.random.rand(20, 20) * 255).astype(np.uint8)) result.save(os.path.join(root, 'a', filename)) a.append(os.path.join(root, 'a', filename)) os.makedirs(os.path.join(root, 'b')) b = [] for filename in ('b1.png', 'b2.png', 'b3.png'): - result = Image.fromarray((np.random.rand(20, 20) * 128).astype(numpy.uint8)) + result = Image.fromarray((np.random.rand(20, 20) * 128).astype(np.uint8)) result.save(os.path.join(root, 'b', filename)) 
b.append(os.path.join(root, 'b', filename)) From e376c3ca0bc58ab6af05ae2d5af25c6458e482dd Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 15:28:49 +0700 Subject: [PATCH 03/10] Add MultiImageFolder class to dataset module Signed-off-by: Sebastien ESKENAZI --- torchvision/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index db5b572a469..7efb0681707 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -24,7 +24,7 @@ from .ucf101 import UCF101 __all__ = ('LSUN', 'LSUNClass', - 'ImageFolder', 'DatasetFolder', 'FakeData', + 'ImageFolder', 'DatasetFolder','MultiImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST', 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', From 1b48e0d593288bd65e3f76199d001796b7050a8c Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 15:41:58 +0700 Subject: [PATCH 04/10] Readd MultiImageFolder to dataset module Signed-off-by: Sebastien ESKENAZI --- torchvision/datasets/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 7efb0681707..280036ff700 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,5 +1,5 @@ from .lsun import LSUN, LSUNClass -from .folder import ImageFolder, DatasetFolder +from .folder import ImageFolder, DatasetFolder, MultiImageFolder from .coco import CocoCaptions, CocoDetection from .cifar import CIFAR10, CIFAR100 from .stl10 import STL10 @@ -24,7 +24,7 @@ from .ucf101 import UCF101 __all__ = ('LSUN', 'LSUNClass', - 'ImageFolder', 'DatasetFolder','MultiImageFolder', 'FakeData', + 'ImageFolder', 'DatasetFolder', 'MultiImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 
'QMNIST', 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', From 3fe08d011b10420069a912b024cb2818b0ebdd4d Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 15:52:41 +0700 Subject: [PATCH 05/10] fix bugs and PEP compliance Signed-off-by: Sebastien ESKENAZI --- torchvision/datasets/folder.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 7c9c53cc8db..c968a62bc9c 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -221,7 +221,7 @@ def make_dataset_noclass(dir, extensions=None, is_valid_file=None): def is_valid_file(x): return has_file_allowed_extension(x, extensions) if not os.path.isdir(dir): - raise ValueError(dir+" is not a directory") + raise ValueError(dir + " is not a directory") for root, _, fnames in sorted(os.walk(dir)): for fname in sorted(fnames): path = os.path.join(root, fname) @@ -265,13 +265,14 @@ def __init__(self, directories, transforms=None, loader=default_loader, is_valid self.loader = loader self.extensions = IMG_EXTENSIONS if is_valid_file is None else None - sampleslist = [make_dataset_noclass(self.directories[i], self.extensions, is_valid_file) for i in range(len(self.directories))] + sampleslist = [make_dataset_noclass(self.directories[i], self.extensions, + is_valid_file) for i in range(len(self.directories))] for list1 in sampleslist: if len(list1) != len(sampleslist[0]): raise ValueError("All directories must contain the same number of images.") if len(list1) == 0: raise (RuntimeError("At least one of the directories does not contain any valid file.\n" - "Supported extensions are: " + ",".join(extensions))) + "Supported extensions are: " + ",".join(IMG_EXTENSIONS))) self.samples = list(zip(*sampleslist)) def __getitem__(self, index): @@ -280,7 +281,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (img1, img2, ...) 
where each image is taken from each of the + tuple: (img1, img2, ...) where each image is taken from each of the initialization directories in the same order. """ paths = self.samples[index] From 4322f6b09d4b67e3ceed71a1b31b6400b0e5cbe7 Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 16:27:08 +0700 Subject: [PATCH 06/10] Fix test bugs Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 30d3963df34..991e0da906a 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -109,12 +109,12 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual(true_samples[i], dataset[i]) + self.assertEqual([*true_samples[i]], dataset[i]) # redo all tests with specified valid image files dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, is_valid_file=lambda x: '3' in x) - true_samples = [['a3.png', 'b3.png']] + true_samples = [true_samples[2]] # test if all images were detected correctly and in the proper order self.assertEqual(len(true_samples), len(dataset.samples)) @@ -123,7 +123,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual(true_samples[i], dataset[i]) + self.assertEqual([*true_samples[i]], dataset[i]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract): From caa6422b47806c82f5a03ef43a019f03efe597d5 Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 16:54:57 +0700 Subject: [PATCH 07/10] Try to fix weird syntax bug Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 991e0da906a..7d85710bdf6 100644 
--- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -113,7 +113,7 @@ def test_multiimagefolder(self): # redo all tests with specified valid image files dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, - is_valid_file=lambda x: '3' in x) + is_valid_file=lambda x: '3.png' in x) true_samples = [true_samples[2]] # test if all images were detected correctly and in the proper order @@ -123,7 +123,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([*true_samples[i]], dataset[i]) + self.assertEqual([ *true_samples[i] ], dataset[i]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract): From 6ffb2df507b05e7abdcf379847df909ef940985a Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 17:06:46 +0700 Subject: [PATCH 08/10] Fix more test bugs Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 7d85710bdf6..9c371c4d1f4 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -109,7 +109,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([*true_samples[i]], dataset[i]) + self.assertEqual([ *true_samples[i] ], dataset[i]) # redo all tests with specified valid image files dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, From a44fd86b7b819479543cf296ebcce57d380654b3 Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 17:23:25 +0700 Subject: [PATCH 09/10] Fix syntax bug in MultiImageFolder test Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 
9c371c4d1f4..7fcaba5d44f 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -109,7 +109,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([ *true_samples[i] ], dataset[i]) + self.assertEqual([ true_samples[i][0], true_samples[i][1] ], dataset[i]) # redo all tests with specified valid image files dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, @@ -123,7 +123,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([ *true_samples[i] ], dataset[i]) + self.assertEqual([ true_samples[i][0], true_samples[i][1] ], dataset[i]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract): From 245b699e05d2fc93c6936ac39aba141c76905092 Mon Sep 17 00:00:00 2001 From: Sebastien ESKENAZI Date: Wed, 18 Sep 2019 17:29:56 +0700 Subject: [PATCH 10/10] Fix PEP issues Signed-off-by: Sebastien ESKENAZI --- test/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 7fcaba5d44f..9e96a32d0fd 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -109,7 +109,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([ true_samples[i][0], true_samples[i][1] ], dataset[i]) + self.assertEqual([true_samples[i][0], true_samples[i][1]], dataset[i]) # redo all tests with specified valid image files dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x, @@ -123,7 +123,7 @@ def test_multiimagefolder(self): # test if the datasets outputs all images correctly for i in range(len(dataset)): - self.assertEqual([ true_samples[i][0], true_samples[i][1] ], dataset[i]) + self.assertEqual([true_samples[i][0], 
true_samples[i][1]], dataset[i]) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') def test_mnist(self, mock_download_extract):