Skip to content

Add MultiImageFolder dataset #1345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
44 changes: 44 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,50 @@ def test_imagefolder(self):
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)

def test_multiimagefolder(self):
    """Check that MultiImageFolder pairs files across directories in order.

    The checks run twice: once with the default extension filter and once
    with a custom ``is_valid_file`` predicate that keeps a single pair.
    """

    def write_fake_images(folder, filenames, scale):
        # Create ``folder`` and fill it with 20x20 random grayscale images,
        # returning the written paths in creation order.
        os.makedirs(folder)
        written = []
        for filename in filenames:
            target = os.path.join(folder, filename)
            result = Image.fromarray((np.random.rand(20, 20) * scale).astype(np.uint8))
            result.save(target)
            written.append(target)
        return written

    def check_dataset(dataset, expected):
        # test if all images were detected correctly and in the proper order
        self.assertEqual(len(expected), len(dataset.samples))
        for wanted, found in zip(expected, dataset.samples):
            self.assertEqual(wanted, found)

        # test if the datasets outputs all images correctly
        # (the loader is the identity, so items are lists of paths)
        for idx in range(len(dataset)):
            self.assertEqual([expected[idx][0], expected[idx][1]], dataset[idx])

    with get_tmp_dir() as root:
        # make fake dataset
        dir_a = os.path.join(root, 'a')
        dir_b = os.path.join(root, 'b')
        a = write_fake_images(dir_a, ('a1.png', 'a2.png', 'a3.png'), 255)
        b = write_fake_images(dir_b, ('b1.png', 'b2.png', 'b3.png'), 128)

        true_samples = list(zip(a, b))
        directories = [dir_a, dir_b]

        dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x)
        check_dataset(dataset, true_samples)

        # redo all tests with specified valid image files
        dataset = torchvision.datasets.MultiImageFolder(directories=directories, loader=lambda x: x,
                                                        is_valid_file=lambda x: '3.png' in x)
        check_dataset(dataset, [true_samples[2]])

@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract):
num_examples = 30
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .folder import ImageFolder, DatasetFolder, MultiImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
Expand All @@ -24,7 +24,7 @@
from .ucf101 import UCF101

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'ImageFolder', 'DatasetFolder', 'MultiImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
Expand Down
101 changes: 101 additions & 0 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import os.path
import sys
import torch.utils.data as data


def has_file_allowed_extension(filename, extensions):
Expand Down Expand Up @@ -148,6 +149,7 @@ def __len__(self):


# Lower-case file suffixes (with the leading dot) treated as image files when
# a dataset is given no custom ``is_valid_file`` predicate.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
# Number of spaces prepended to every body line of a dataset's ``__repr__``.
REPR_INDENT = 4


def pil_loader(path):
Expand Down Expand Up @@ -208,3 +210,102 @@ def __init__(self, root, transform=None, target_transform=None,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples


def make_dataset_noclass(dir, extensions=None, is_valid_file=None):
    """Collect the paths of all valid files found under a directory.

    Unlike a class-folder layout, sub-directories carry no label meaning here:
    the tree rooted at ``dir`` is walked recursively and every accepted file
    is returned in one flat list.

    Args:
        dir (string): Directory to scan; a leading ``~`` is expanded.
        extensions (tuple of strings, optional): Allowed file extensions.
            Mutually exclusive with ``is_valid_file``.
        is_valid_file (callable, optional): Predicate taking a file path and
            returning whether to keep it. Mutually exclusive with ``extensions``.

    Returns:
        list: Accepted file paths, ordered by containing directory and then
        by file name.

    Raises:
        ValueError: If neither or both of ``extensions`` and ``is_valid_file``
            are given, or if ``dir`` is not an existing directory.
    """
    dir = os.path.expanduser(dir)

    # Exactly one of the two filtering mechanisms has to be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is None:
        accept = is_valid_file
    else:
        def accept(candidate):
            return has_file_allowed_extension(candidate, extensions)

    if not os.path.isdir(dir):
        raise ValueError(dir + " is not a directory")

    # Sort both the walked directories and the names inside each of them so
    # the result does not depend on the file system's listing order.
    return [
        path
        for folder, _, names in sorted(os.walk(dir))
        for path in (os.path.join(folder, name) for name in sorted(names))
        if accept(path)
    ]


class MultiImageFolder(data.Dataset):
    """A dataset to load multiple images from separate directories at each iteration.

    The directories must be arranged in this way: ::

        dir1/123.png
        dir1/456.png
        dir1/789.png
        dir1/987.png

        dir2/abc.png
        dir2/def.png
        dir2/ghi.png
        dir2/jkl.png

    Each directory is scanned recursively and its valid file paths are sorted.
    Sample ``i`` groups the ``i``-th path of every directory, so images that
    belong together must sit at the same position in each sorted listing.

    Args:
        directories (list): List of directories where the images are. They must all contain the same number of images.
        transforms (list of callable, optional): A list of function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``. When given, it must hold
            exactly one entry per directory; entry ``k`` is applied to the image from directory ``k``.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of a file and returns
            whether it is a valid file (used to check for corrupt files). When omitted, files are
            selected by the extensions in ``IMG_EXTENSIONS``.

    Attributes:
        samples (list): List of path tuples (image path 1, image path 2, ...), one tuple per sample,
            where ``len(samples[x]) == len(directories)``.

    Raises:
        ValueError: If ``transforms`` is given with a length different from ``directories``,
            or if the directories do not all contain the same number of valid files.
        RuntimeError: If the directories contain no valid file.
    """

    def __init__(self, directories, transforms=None, loader=default_loader, is_valid_file=None):
        self.directories = [os.path.expanduser(a) for a in directories]
        self.transforms = transforms
        if self.transforms is not None and len(self.directories) != len(self.transforms):
            raise ValueError("There must be exactly one transform per directory or no transform at all.")
        self.loader = loader
        # make_dataset_noclass needs exactly one of extensions / is_valid_file.
        self.extensions = IMG_EXTENSIONS if is_valid_file is None else None

        per_directory = [make_dataset_noclass(directory, self.extensions, is_valid_file)
                         for directory in self.directories]
        for paths in per_directory:
            if len(paths) != len(per_directory[0]):
                raise ValueError("All directories must contain the same number of images.")
            if len(paths) == 0:
                raise (RuntimeError("At least one of the directories does not contain any valid file.\n"
                                    "Supported extensions are: " + ",".join(IMG_EXTENSIONS)))
        # Transpose: one tuple per sample holding one path from each directory.
        self.samples = list(zip(*per_directory))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            list: [img1, img2, ...] where each image is taken from each of the
            initialization directories in the same order.
        """
        paths = self.samples[index]
        sample = [self.loader(path) for path in paths]
        if self.transforms is not None:
            sample = [transform(image) for transform, image in zip(self.transforms, sample)]
        return sample

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        for directory in self.directories:
            body.append("Directory location: {}".format(directory))
        # Emit the subclass-provided lines exactly once (they used to be
        # appended both before and after the transforms).
        body += self.extra_repr().splitlines()
        if self.transforms is not None:
            body += [repr(transform) for transform in self.transforms]
        lines = [head] + [" " * REPR_INDENT + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self):
        """Return extra lines for ``__repr__``; subclasses may override."""
        return ""