diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 8b281c8fbb6..edd094a9921 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -13,6 +13,7 @@
 from .flickr import Flickr8k, Flickr30k
 from .voc import VOCSegmentation, VOCDetection
 from .cityscapes import Cityscapes
+from .imagenet import ImageNet
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
@@ -22,5 +23,5 @@
            'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'KMNIST',
            'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
-           'VOCSegmentation', 'VOCDetection', 'Cityscapes',
+           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA')
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
new file mode 100644
index 00000000000..8a1268944eb
--- /dev/null
+++ b/torchvision/datasets/imagenet.py
@@ -0,0 +1,236 @@
+from __future__ import print_function
+import os
+import shutil
+import torch
+from .folder import ImageFolder
+from .utils import check_integrity, download_url
+
+ARCHIVE_DICT = {
+    'train': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
+        'md5': '1d675b47d978889d74fa0da5fadfb00e',
+    },
+    'val': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
+        'md5': '29b22e2961454d5413ddabcf34fc5622',
+    },
+    'devkit': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
+        'md5': 'fa75699e90414af021442c21a62c3abf',
+    }
+}
+
+META_DICT = {
+    'filename': 'meta.bin',
+    'md5': '7e0d3cf156177e4fc47011cdd30ce706',
+}
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    Args:
+        root (string): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train`` or ``val``.
+        download (bool, optional): If True, downloads the dataset from the internet
+            and puts it in the root directory. If the dataset is already downloaded,
+            it is not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+    Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root, split='train', download=False, **kwargs):
+        root = self.root = os.path.expanduser(root)
+        self.split = self._verify_split(split)
+
+        if download:
+            self.download()
+        wnid_to_classes = self._load_meta_file()[0]
+
+        super(ImageNet, self).__init__(self.split_folder, **kwargs)
+        self.root = root
+
+        # ImageFolder indexed the WNID-named folders; remap to readable class names.
+        self.wnids = self.classes
+        self.wnid_to_idx = self.class_to_idx
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx
+                             for idx, clss in enumerate(self.classes)
+                             for cls in clss}
+
+    def download(self):
+        if not self._check_meta_file_integrity():
+            tmpdir = os.path.join(self.root, 'tmp')
+
+            archive_dict = ARCHIVE_DICT['devkit']
+            download_and_extract_tar(archive_dict['url'], self.root,
+                                     extract_root=tmpdir,
+                                     md5=archive_dict['md5'])
+            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
+            meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
+            self._save_meta_file(*meta)
+
+            shutil.rmtree(tmpdir)
+
+        if not os.path.isdir(self.split_folder):
+            archive_dict = ARCHIVE_DICT[self.split]
+            download_and_extract_tar(archive_dict['url'], self.root,
+                                     extract_root=self.split_folder,
+                                     md5=archive_dict['md5'])
+
+            if self.split == 'train':
+                prepare_train_folder(self.split_folder)
+            elif self.split == 'val':
+                val_wnids = self._load_meta_file()[1]
+                prepare_val_folder(self.split_folder, val_wnids)
+        else:
+            msg = ("You set download=True, but a folder '{}' already exists in "
+                   "the root directory. If you want to re-download or re-extract the "
+                   "archive, delete the folder.")
+            print(msg.format(self.split))
+
+    @property
+    def meta_file(self):
+        return os.path.join(self.root, META_DICT['filename'])
+
+    def _check_meta_file_integrity(self):
+        return check_integrity(self.meta_file, META_DICT['md5'])
+
+    def _load_meta_file(self):
+        if self._check_meta_file_integrity():
+            return torch.load(self.meta_file)
+        else:
+            raise RuntimeError("Meta file not found or corrupted. "
+                               "You can use download=True to create it.")
+
+    def _save_meta_file(self, wnid_to_class, val_wnids):
+        torch.save((wnid_to_class, val_wnids), self.meta_file)
+
+    def _verify_split(self, split):
+        if split not in self.valid_splits:
+            msg = "Unknown split {}.".format(split)
+            msg += " Valid splits are {}.".format(", ".join(self.valid_splits))
+            raise ValueError(msg)
+        return split
+
+    @property
+    def valid_splits(self):
+        return 'train', 'val'
+
+    @property
+    def split_folder(self):
+        return os.path.join(self.root, self.split)
+
+    def __repr__(self):
+        head = "Dataset " + self.__class__.__name__
+        body = ["Number of datapoints: {}".format(self.__len__())]
+        if self.root is not None:
+            body.append("Root location: {}".format(self.root))
+        body += ["Split: {}".format(self.split)]
+        if hasattr(self, 'transform') and self.transform is not None:
+            body += self._format_transform_repr(self.transform,
+                                                "Transforms: ")
+        if hasattr(self, 'target_transform') and self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform,
+                                                "Target transforms: ")
+        lines = [head] + [" " * 4 + line for line in body]
+        return '\n'.join(lines)
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+
+def extract_tar(src, dest=None, gzip=None, delete=False):
+    import tarfile
+
+    if dest is None:
+        dest = os.path.dirname(src)
+    if gzip is None:
+        gzip = src.lower().endswith('.gz')
+
+    mode = 'r:gz' if gzip else 'r'
+    with tarfile.open(src, mode) as tarfh:
+        tarfh.extractall(path=dest)
+
+    if delete:
+        os.remove(src)
+
+
+def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
+                             md5=None, **kwargs):
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if filename is None:
+        filename = os.path.basename(url)
+
+    if not check_integrity(os.path.join(download_root, filename), md5):
+        download_url(url, download_root, filename=filename, md5=md5)
+
+    extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
+
+
+def parse_devkit(root):
+    idx_to_wnid, wnid_to_classes = parse_meta(root)
+    val_idcs = parse_val_groundtruth(root)
+    val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+    return wnid_to_classes, val_wnids
+
+
+def parse_meta(devkit_root, path='data', filename='meta.mat'):
+    import scipy.io as sio
+
+    metafile = os.path.join(devkit_root, path, filename)
+    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+    nums_children = list(zip(*meta))[4]
+    # keep only the leaf synsets, i.e. the actual classification classes
+    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
+            if num_children == 0]
+    idcs, wnids, classes = list(zip(*meta))[:3]
+    classes = [tuple(clss.split(', ')) for clss in classes]
+    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+    return idx_to_wnid, wnid_to_classes
+
+
+def parse_val_groundtruth(devkit_root, path='data',
+                          filename='ILSVRC2012_validation_ground_truth.txt'):
+    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
+        val_idcs = txtfh.readlines()
+    return [int(val_idx) for val_idx in val_idcs]
+
+
+def prepare_train_folder(folder):
+    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
+        extract_tar(archive, os.path.splitext(archive)[0], delete=True)
+
+
+def prepare_val_folder(folder, wnids):
+    # the sorted file names line up with the order of the ground-truth WNIDs
+    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(folder, wnid))
+
+    for wnid, img_file in zip(wnids, img_files):
+        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
+
+
+def _splitexts(root):
+    # split off *all* extensions, e.g. 'foo.tar.gz' -> ('foo', '.tar.gz')
+    exts = []
+    ext = '.'
+    while ext:
+        root, ext = os.path.splitext(root)
+        exts.append(ext)
+    return root, ''.join(reversed(exts))
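
For context, a minimal usage sketch of the new class (the root path below is hypothetical; with download=True the archives are fetched, verified against the MD5 sums above, and extracted on first use, after which the cached meta.bin and split folder are reused):

    from torchvision import datasets, transforms

    # hypothetical root directory; archives are downloaded and extracted here
    dataset = datasets.ImageNet('~/datasets/imagenet', split='val', download=True,
                                transform=transforms.ToTensor())

    img, target = dataset[0]          # transformed image and its class index
    print(dataset.classes[target])    # tuple of human-readable names for that class
    print(dataset.wnids[target])      # the corresponding WordNet ID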
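
The meta.bin file written by _save_meta_file is just a torch-serialized tuple of the two structures returned by parse_devkit; roughly (the WNIDs and class names below are illustrative):

    import torch

    wnid_to_classes, val_wnids = torch.load('meta.bin')
    # wnid_to_classes: {'n01440764': ('tench', 'Tinca tinca'), ...}  one entry per class
    # val_wnids:       ['n01751748', ...]  one WNID per validation image,
    #                                      in sorted file-name order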
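
After a successful val run, the root directory ends up in the layout ImageFolder expects (exact file placement illustrative):

    imagenet/
    ├── meta.bin
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── ILSVRC2012_img_val.tar
    └── val/
        ├── n01440764/
        │   └── ILSVRC2012_val_00000293.JPEG ...
        └── ...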