diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 3e6e21729c8..a6f2af387dc 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -149,3 +149,15 @@ Flickr
.. autoclass:: Flickr30k
:members: __getitem__
:special-members:
+
+VOC
+~~~~~~
+
+
+.. autoclass:: VOCSegmentation
+ :members: __getitem__
+ :special-members:
+
+.. autoclass:: VOCDetection
+ :members: __getitem__
+ :special-members:
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 19b6b4baa8a..6ab5722315d 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -11,10 +11,12 @@
from .omniglot import Omniglot
from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
+from .voc import VOCSegmentation, VOCDetection
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
- 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k')
+ 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
+ 'VOCSegmentation', 'VOCDetection')
diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py
new file mode 100644
index 00000000000..f886d701bca
--- /dev/null
+++ b/torchvision/datasets/voc.py
@@ -0,0 +1,244 @@
+import os
+import sys
+import tarfile
+import collections
+import torch.utils.data as data
+if sys.version_info[0] == 2:
+ import xml.etree.cElementTree as ET
+else:
+ import xml.etree.ElementTree as ET
+
+from PIL import Image
+from .utils import download_url, check_integrity
+
+DATASET_YEAR_DICT = {  # per-year archive metadata: download url, local filename, md5, extracted base_dir
+    '2012': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
+        'filename': 'VOCtrainval_11-May-2012.tar',
+        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
+        'base_dir': 'VOCdevkit/VOC2012'
+    },
+    '2011': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
+        'filename': 'VOCtrainval_25-May-2011.tar',
+        'md5': '6c3384ef61512963050cb5d687e5bf1e',
+        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
+    },
+    '2010': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
+        'filename': 'VOCtrainval_03-May-2010.tar',
+        'md5': 'da459979d0c395079b5c75ee67908abb',
+        'base_dir': 'VOCdevkit/VOC2010'
+    },
+    '2009': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
+        'filename': 'VOCtrainval_11-May-2009.tar',
+        'md5': '59065e4b188729180974ef6572f6a212',
+        'base_dir': 'VOCdevkit/VOC2009'
+    },
+    '2008': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
+        'filename': 'VOCtrainval_14-Jul-2008.tar',  # local name now matches the archive in 'url'
+        'md5': '2629fa636546599198acfcfbfcf1904a',
+        'base_dir': 'VOCdevkit/VOC2008'
+    },
+    '2007': {
+        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
+        'filename': 'VOCtrainval_06-Nov-2007.tar',
+        'md5': 'c52e279531787c972589f7e41ab4ae64',
+        'base_dir': 'VOCdevkit/VOC2007'
+    }
+}
+
+
+class VOCSegmentation(data.Dataset):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years 2007 to 2012.
+        image_set (string, optional): Name of the split to load, ``train``, ``trainval``
+            or ``val``; a matching ``.txt`` file must exist in ``ImageSets/Segmentation``.
+        download (bool, optional): If true, fetches the archive for ``year`` and extracts
+            it into ``root`` before the dataset files are looked up.
+        transform (callable, optional): A function/transform applied to the RGB PIL
+            image of each sample. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform applied to the
+            segmentation mask of each sample.
+    """
+
+    def __init__(self,
+                 root,
+                 year='2012',
+                 image_set='train',
+                 download=False,
+                 transform=None,
+                 target_transform=None):
+        self.root = root
+        self.year = year
+        year_spec = DATASET_YEAR_DICT[year]
+        self.url = year_spec['url']
+        self.filename = year_spec['filename']
+        self.md5 = year_spec['md5']
+        self.transform = transform
+        self.target_transform = target_transform
+        self.image_set = image_set
+        voc_root = os.path.join(self.root, year_spec['base_dir'])
+        image_dir = os.path.join(voc_root, 'JPEGImages')
+        mask_dir = os.path.join(voc_root, 'SegmentationClass')
+
+        if download:
+            download_extract(self.url, self.root, self.filename, self.md5)
+
+        if not os.path.isdir(voc_root):
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        split_f = os.path.join(voc_root, 'ImageSets/Segmentation',
+                               image_set.rstrip('\n') + '.txt')
+
+        if not os.path.exists(split_f):
+            raise ValueError(
+                'Wrong image_set entered! Please use image_set="train" '
+                'or image_set="trainval" or image_set="val"')
+
+        # Each line of the split file holds one sample name (without extension).
+        with open(split_f, "r") as f:
+            file_names = [line.strip() for line in f]
+
+        self.images = [os.path.join(image_dir, name + ".jpg") for name in file_names]
+        self.masks = [os.path.join(mask_dir, name + ".png") for name in file_names]
+        assert (len(self.images) == len(self.masks))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Position of the sample in the selected image set.
+
+        Returns:
+            tuple: (image, target) where target is the image segmentation.
+        """
+        image = Image.open(self.images[index]).convert('RGB')
+        mask = Image.open(self.masks[index])
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            mask = self.target_transform(mask)
+
+        return image, mask
+
+    def __len__(self):
+        return len(self.images)
+
+
+class VOCDetection(data.Dataset):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years 2007 to 2012.
+        image_set (string, optional): Select the image_set to use, ``train``, ``trainval``,
+            ``val`` or any other split that has a ``.txt`` file in ``ImageSets/Main``.
+        download (bool, optional): If true, fetches the archive for ``year`` from the
+            internet and extracts it into ``root`` before the dataset files are
+            looked up.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(self,
+                 root,
+                 year='2012',
+                 image_set='train',
+                 download=False,
+                 transform=None,
+                 target_transform=None):
+        self.root = root
+        self.year = year
+        self.url = DATASET_YEAR_DICT[year]['url']
+        self.filename = DATASET_YEAR_DICT[year]['filename']
+        self.md5 = DATASET_YEAR_DICT[year]['md5']
+        self.transform = transform
+        self.target_transform = target_transform
+        self.image_set = image_set
+
+        base_dir = DATASET_YEAR_DICT[year]['base_dir']
+        voc_root = os.path.join(self.root, base_dir)
+        image_dir = os.path.join(voc_root, 'JPEGImages')
+        annotation_dir = os.path.join(voc_root, 'Annotations')
+
+        if download:
+            download_extract(self.url, self.root, self.filename, self.md5)
+
+        if not os.path.isdir(voc_root):
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        splits_dir = os.path.join(voc_root, 'ImageSets/Main')
+
+        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
+
+        if not os.path.exists(split_f):
+            raise ValueError(
+                'Wrong image_set entered! Please use image_set="train" '
+                'or image_set="trainval" or image_set="val" or a valid '
+                'image_set from the VOC ImageSets/Main folder.')
+
+        with open(split_f, "r") as f:
+            file_names = [x.strip() for x in f]
+
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
+        assert (len(self.images) == len(self.annotations))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dictionary of the XML tree.
+        """
+        img = Image.open(self.images[index]).convert('RGB')
+        target = self.parse_voc_xml(
+            ET.parse(self.annotations[index]).getroot())
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.images)
+
+    def parse_voc_xml(self, node):
+        """Recursively convert an XML element into a nested dictionary.
+
+        An element with children maps its tag to a dict of the parsed children; a child
+        tag that occurs more than once is collected into a list. A leaf element maps
+        its tag to its stripped text, or yields an empty dict when it has no text.
+        """
+        children = list(node)
+        if children:
+            grouped = collections.defaultdict(list)
+            for child_dict in map(self.parse_voc_xml, children):
+                for tag, value in child_dict.items():
+                    grouped[tag].append(value)
+            return {node.tag: {tag: vals[0] if len(vals) == 1 else vals
+                               for tag, vals in grouped.items()}}
+        if node.text:
+            return {node.tag: node.text.strip()}
+        return {}
+
+
+def download_extract(url, root, filename, md5):  # fetch an archive via download_url, then unpack root/filename into root
+ download_url(url, root, filename, md5)  # helper from .utils; NOTE(review): presumably checks the file against md5 -- confirm
+ with tarfile.open(os.path.join(root, filename), "r") as tar:  # plain "r" mode: tarfile handles compression transparently
+ tar.extractall(path=root)  # NOTE(review): extractall trusts the member paths stored in the archive