Skip to content

SegmentationDataset [feature proposal and demo] #1330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,88 @@ def __init__(self, root, transform=None, target_transform=None,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples


def default_label_loader(path):
    """Load a segmentation label map from ``path`` as a grayscale PIL image.

    Labels are assumed to be stored as grayscale images in which the
    luminance value of each pixel encodes its class index.

    Args:
        path (string): Path to the label file.

    Returns:
        PIL Image: The label converted to mode ``'L'``.
    """
    with open(path, 'rb') as label_file:
        label = Image.open(label_file)
        # convert() is called while the file is still open
        return label.convert('L')


class SegmentationDataset(VisionDataset):
    """A generic segmentation data loader where the files are arranged in this way: ::

        root/images/xxx.ext
        root/images/nsdf3.ext
        root/images/asdxx932_.ext

        root/labels/xxx.ext
        root/labels/nsdf3.ext
        root/labels/asdxx932_.ext

    An image and a label form a sample when their file names are equal once
    the extension is stripped (extensions themselves do not have to match).
    Files without a counterpart in the other folder are skipped. If several
    files in one folder share the same extension-less name, only the first
    one in sorted order is used.

    Args:
        root (string): Root directory path.
        folders (tuple, 2 strings): Names of subfolders containing images and
            labels, respectively.
        image_loader (callable): A function to load an image given its path.
        label_loader (callable): A function to load a label given its path.
        transforms (callable, optional): A function/transform that takes in
            a tuple of PIL images (image, label) and returns their transformed
            versions. Most likely this will be a SegmentationCompose with a number
            of generic transforms.
        is_valid_file (callable, optional): A function that takes the path of a
            file and returns whether it is valid (used to skip corrupt files).
            It is called for every candidate file in both the image and the
            label folder; files for which it returns a false value are ignored.

    Attributes:
        images (list): List of image file paths.
        labels (list): List of label file paths, aligned index-by-index with ``images``.

    Raises:
        RuntimeError: If the image or the label folder does not exist.
    """

    def __init__(self, root, folders=('images', 'labels'), image_loader=pil_loader,
                 label_loader=default_label_loader, transforms=None, is_valid_file=None):
        super(SegmentationDataset, self).__init__(root, transforms)
        # check if the required folders exist (a plain file with that name does not count)
        self.images_dir = os.path.join(root, folders[0])
        self.labels_dir = os.path.join(root, folders[1])
        if not os.path.isdir(self.images_dir):
            raise RuntimeError("No image folder found in " + self.root + ".")
        if not os.path.isdir(self.labels_dir):
            raise RuntimeError("No label folder found in " + self.root + ".")

        image_entries = self._list_by_stem(self.images_dir, is_valid_file)
        label_paths = dict(self._list_by_stem(self.labels_dir, is_valid_file))

        # Pair every image with the label of the same extension-less name.
        # Looking the label up by name (instead of filtering two independently
        # sorted listings) keeps both lists aligned even when the extensions
        # differ or one folder holds extra files.
        self.images = []
        self.labels = []
        for stem, image_path in image_entries:
            label_path = label_paths.get(stem)
            if label_path is None:
                continue
            self.images.append(image_path)
            self.labels.append(label_path)

        # continue with the technicalities
        self.image_loader = image_loader
        self.label_loader = label_loader

    @staticmethod
    def _list_by_stem(directory, is_valid_file=None):
        """Return ``(stem, path)`` pairs for the usable files of ``directory``.

        Entries are ordered by file name; sub-directories, files rejected by
        ``is_valid_file`` and repeated stems (same name, different extension)
        are left out.
        """
        entries = []
        seen = set()
        for name in sorted(os.listdir(directory)):
            # directory already contains the root, so it must not be joined again
            path = os.path.join(directory, name)
            if not os.path.isfile(path):
                continue
            if is_valid_file is not None and not is_valid_file(path):
                continue
            stem = os.path.splitext(name)[0]
            if stem in seen:
                continue
            seen.add(stem)
            entries.append((stem, path))
        return entries

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, label), after ``transforms`` has been applied if it was given.
        """
        image = self.image_loader(self.images[index])
        label = self.label_loader(self.labels[index])
        if self.transforms is not None:
            # transforms receives the pair as a single tuple argument
            image, label = self.transforms((image, label))
        return (image, label)

    def __len__(self):
        """Return the number of matched (image, label) pairs."""
        return len(self.images)