# Originally a hunk (@@ -208,3 +208,88 @@) of a diff against
# torchvision/datasets/folder.py; the definitions below are appended to that
# module after the existing ``self.imgs = self.samples`` line.


def default_label_loader(path):
    """Load a label map from ``path`` as a single-channel ('L' mode) PIL image.

    Assumes that labels are stored as grayscale images, with the luminance
    value encoding the class index.

    Args:
        path (string): Path to the label file.

    Returns:
        PIL.Image.Image: The label converted to mode ``'L'``.
    """
    with open(path, 'rb') as f:
        label = Image.open(f)
        # Convert inside the ``with`` block, while the underlying file is still open.
        return label.convert('L')


def _index_files_by_stem(directory):
    """Map each extension-less file name in ``directory`` to one file name.

    When several files share the same stem (e.g. ``a.jpg`` and ``a.png``) the
    first one in sorted order is kept, so the result is deterministic.
    """
    by_stem = {}
    for name in sorted(os.listdir(directory)):
        by_stem.setdefault(os.path.splitext(name)[0], name)
    return by_stem


class SegmentationDataset(VisionDataset):
    """A generic data loader where images and labels are arranged in this way: ::

        root/images/xxx.ext
        root/images/nsdf3.ext
        root/images/asdxx932_.ext

        root/labels/xxx.ext
        root/labels/nsdf3.ext
        root/labels/asdxx932_.ext

    An image and a label belong together when their file names are equal once
    the extension is removed. Files that have no counterpart in the other
    folder are ignored.

    Args:
        root (string): Root directory path.
        folders (tuple, 2 strings): Names of subfolders containing images and
            labels, respectively.
        image_loader (callable): A function to load an image given its path.
        label_loader (callable): A function to load a label given its path.
        transforms (callable, optional): A function/transform that takes in
            a tuple of PIL images (image, label) and returns their transformed
            versions. Most likely this will be a SegmentationCompose with a number
            of generic transforms.
        is_valid_file (callable, optional): A function that takes the path of a
            file and checks whether it is a valid file (e.g. to skip corrupt
            files). NOTE: accepted for interface compatibility but not applied
            yet (see the TODO in ``__init__``).

    Attributes:
        images_dir (string): Path of the folder containing the images.
        labels_dir (string): Path of the folder containing the labels.
        images (list): List of image file paths.
        labels (list): List of label file paths; ``labels[i]`` matches ``images[i]``.

    Raises:
        RuntimeError: If the image folder or the label folder does not exist.
    """

    def __init__(self, root, folders=('images', 'labels'), image_loader=pil_loader,
                 label_loader=default_label_loader, transforms=None, is_valid_file=None):
        super(SegmentationDataset, self).__init__(root, transforms)
        # check if the required folders exist
        self.images_dir = os.path.join(root, folders[0])
        self.labels_dir = os.path.join(root, folders[1])
        if not os.path.isdir(self.images_dir):
            raise RuntimeError("No image folder found in " + self.root + ".")
        if not os.path.isdir(self.labels_dir):
            raise RuntimeError("No label folder found in " + self.root + ".")
        # TODO: is_valid_file
        # Index both folders by extension-less name and keep only the names
        # present in both. Building the two lists from the same sorted sequence
        # of names guarantees that images[i] and labels[i] always correspond.
        images_by_stem = _index_files_by_stem(self.images_dir)
        labels_by_stem = _index_files_by_stem(self.labels_dir)
        matched_names = sorted(set(images_by_stem).intersection(labels_by_stem))
        # images_dir / labels_dir already start with ``root``; joining ``root``
        # again would produce ``root/root/...`` for relative roots.
        self.images = [
            os.path.join(self.images_dir, images_by_stem[stem])
            for stem in matched_names
        ]
        self.labels = [
            os.path.join(self.labels_dir, labels_by_stem[stem])
            for stem in matched_names
        ]
        # continue with the technicalities
        self.image_loader = image_loader
        self.label_loader = label_loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, label), passed through ``self.transforms`` when it
            is not None.
        """
        image = self.image_loader(self.images[index])
        label = self.label_loader(self.labels[index])
        if self.transforms is not None:
            # the transform receives the pair as a single tuple argument
            image, label = self.transforms((image, label))
        return (image, label)

    def __len__(self):
        """Return the number of matched (image, label) pairs."""
        return len(self.images)