diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 3e6e21729c8..a6f2af387dc 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -149,3 +149,15 @@ Flickr .. autoclass:: Flickr30k :members: __getitem__ :special-members: + +VOC +~~~~~~ + + +.. autoclass:: VOCSegmentation + :members: __getitem__ + :special-members: + +.. autoclass:: VOCDetection + :members: __getitem__ + :special-members: diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 19b6b4baa8a..6ab5722315d 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -11,10 +11,12 @@ from .omniglot import Omniglot from .sbu import SBU from .flickr import Flickr8k, Flickr30k +from .voc import VOCSegmentation, VOCDetection __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k') + 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', + 'VOCSegmentation', 'VOCDetection') diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py new file mode 100644 index 00000000000..f886d701bca --- /dev/null +++ b/torchvision/datasets/voc.py @@ -0,0 +1,244 @@ +import os +import sys +import tarfile +import collections +import torch.utils.data as data +if sys.version_info[0] == 2: + import xml.etree.cElementTree as ET +else: + import xml.etree.ElementTree as ET + +from PIL import Image +from .utils import download_url, check_integrity + +DATASET_YEAR_DICT = { + '2012': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '6cd6e144f989b92b3379bac3b3de84fd', + 'base_dir': 'VOCdevkit/VOC2012' + }, + '2011': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', + 'filename': 'VOCtrainval_25-May-2011.tar', + 'md5': '6c3384ef61512963050cb5d687e5bf1e', + 'base_dir': 'TrainVal/VOCdevkit/VOC2011' + }, + '2010': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', + 'filename': 'VOCtrainval_03-May-2010.tar', + 'md5': 'da459979d0c395079b5c75ee67908abb', + 'base_dir': 'VOCdevkit/VOC2010' + }, + '2009': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', + 'filename': 'VOCtrainval_11-May-2009.tar', + 'md5': '59065e4b188729180974ef6572f6a212', + 'base_dir': 'VOCdevkit/VOC2009' + }, + '2008': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '2629fa636546599198acfcfbfcf1904a', + 'base_dir': 'VOCdevkit/VOC2008' + }, + '2007': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', + 'filename': 'VOCtrainval_06-Nov-2007.tar', + 'md5': 'c52e279531787c972589f7e41ab4ae64', + 'base_dir': 'VOCdevkit/VOC2007' + } +} + + +class VOCSegmentation(data.Dataset): + """`Pascal VOC `_ Segmentation Dataset. + + Args: + root (string): Root directory of the VOC Dataset. + year (string, optional): The dataset year, supports years 2007 to 2012. + image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, + root, + year='2012', + image_set='train', + download=False, + transform=None, + target_transform=None): + self.root = root + self.year = year + self.url = DATASET_YEAR_DICT[year]['url'] + self.filename = DATASET_YEAR_DICT[year]['filename'] + self.md5 = DATASET_YEAR_DICT[year]['md5'] + self.transform = transform + self.target_transform = target_transform + self.image_set = image_set + base_dir = DATASET_YEAR_DICT[year]['base_dir'] + voc_root = os.path.join(self.root, base_dir) + image_dir = os.path.join(voc_root, 'JPEGImages') + mask_dir = os.path.join(voc_root, 'SegmentationClass') + + if download: + download_extract(self.url, self.root, self.filename, self.md5) + + if not os.path.isdir(voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') + + split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') + + if not os.path.exists(split_f): + raise ValueError( + 'Wrong image_set entered! Please use image_set="train" ' + 'or image_set="trainval" or image_set="val"') + + with open(os.path.join(split_f), "r") as f: + file_names = [x.strip() for x in f.readlines()] + + self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] + self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] + assert (len(self.images) == len(self.masks)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the image segmentation. + """ + img = Image.open(self.images[index]).convert('RGB') + target = Image.open(self.masks[index]) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.images) + + +class VOCDetection(data.Dataset): + """`Pascal VOC `_ Detection Dataset. + + Args: + root (string): Root directory of the VOC Dataset. + year (string, optional): The dataset year, supports years 2007 to 2012. + image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + (default: alphabetic indexing of VOC's 20 classes). + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, required): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, + root, + year='2012', + image_set='train', + download=False, + transform=None, + target_transform=None): + self.root = root + self.year = year + self.url = DATASET_YEAR_DICT[year]['url'] + self.filename = DATASET_YEAR_DICT[year]['filename'] + self.md5 = DATASET_YEAR_DICT[year]['md5'] + self.transform = transform + self.target_transform = target_transform + self.image_set = image_set + + base_dir = DATASET_YEAR_DICT[year]['base_dir'] + voc_root = os.path.join(self.root, base_dir) + image_dir = os.path.join(voc_root, 'JPEGImages') + annotation_dir = os.path.join(voc_root, 'Annotations') + + if download: + download_extract(self.url, self.root, self.filename, self.md5) + + if not os.path.isdir(voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + splits_dir = os.path.join(voc_root, 'ImageSets/Main') + + split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') + + if not os.path.exists(split_f): + raise ValueError( + 'Wrong image_set entered! Please use image_set="train" ' + 'or image_set="trainval" or image_set="val" or a valid' + 'image_set from the VOC ImageSets/Main folder.') + + with open(os.path.join(split_f), "r") as f: + file_names = [x.strip() for x in f.readlines()] + + self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] + self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names] + assert (len(self.images) == len(self.annotations)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is a dictionary of the XML tree. + """ + img = Image.open(self.images[index]).convert('RGB') + target = self.parse_voc_xml( + ET.parse(self.annotations[index]).getroot()) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.images) + + def parse_voc_xml(self, node): + voc_dict = {} + children = list(node) + if children: + def_dic = collections.defaultdict(list) + for dc in map(self.parse_voc_xml, children): + for ind, v in dc.items(): + def_dic[ind].append(v) + voc_dict = { + node.tag: + {ind: v[0] if len(v) == 1 else v + for ind, v in def_dic.items()} + } + if node.text: + text = node.text.strip() + if not children: + voc_dict[node.tag] = text + return voc_dict + + +def download_extract(url, root, filename, md5): + download_url(url, root, filename, md5) + with tarfile.open(os.path.join(root, filename), "r") as tar: + tar.extractall(path=root)