Skip to content

Commit 6938291

Browse files
Philip Meier (fmassa)
Philip Meier
authored and committed
ImageNet dataset (#764)
* initial commit * fixed Python2 issue * fixed naming incorrectness and Python2 compability * fixed preparation of train folder * removed detection dataset * added docstring and repr * moved import of scipy to make the import of torchvision independent of it * improved conversion from class string to index * removed support for other years than 2012 * removed accidentally added file * moved emptying of split folder to avoid accidental deletion * removed deletion of the images * removed error conversion for Python2 * Aligned class indices with the indices identified by ImageFolder class
1 parent 71322cb commit 6938291

File tree

2 files changed

+238
-1
lines changed

2 files changed

+238
-1
lines changed

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .flickr import Flickr8k, Flickr30k
1414
from .voc import VOCSegmentation, VOCDetection
1515
from .cityscapes import Cityscapes
16+
from .imagenet import ImageNet
1617
from .caltech import Caltech101, Caltech256
1718
from .celeba import CelebA
1819

@@ -22,5 +23,5 @@
2223
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
2324
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
2425
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
25-
'VOCSegmentation', 'VOCDetection', 'Cityscapes',
26+
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
2627
'Caltech101', 'Caltech256', 'CelebA')

torchvision/datasets/imagenet.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
from __future__ import print_function
2+
import os
3+
import shutil
4+
import torch
5+
from .folder import ImageFolder
6+
from .utils import check_integrity, download_url
7+
8+
# Download URLs and expected md5 checksums for the ILSVRC 2012 archives.
# 'train' and 'val' hold the image tarballs; 'devkit' holds the metadata
# (synset listing and validation ground truth) parsed by parse_devkit().
ARCHIVE_DICT = {
    'train': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
        'md5': '1d675b47d978889d74fa0da5fadfb00e',
    },
    'val': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
        'md5': '29b22e2961454d5413ddabcf34fc5622',
    },
    'devkit': {
        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
        'md5': 'fa75699e90414af021442c21a62c3abf',
    }
}

# Name and expected md5 of the serialized (wnid_to_classes, val_wnids) tuple
# that ImageNet._save_meta_file() writes into the dataset root.
META_DICT = {
    'filename': 'meta.bin',
    'md5': '7e0d3cf156177e4fc47011cdd30ce706',
}
27+
28+
29+
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names (one tuple of names per class).
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        if download:
            self.download()
        wnid_to_classes = self._load_meta_file()[0]

        # The split folder contains one subfolder per wnid, so ImageFolder
        # initially uses the wnids as class names and assigns the indices.
        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        self.root = root

        # Re-express the wnid-based attributes set up by ImageFolder in terms
        # of human-readable class names, keeping the indices ImageFolder chose.
        # Bug fix: the original zipped the per-image index list against the
        # per-class wnid/class lists, which assigns wrong indices whenever a
        # class contains more than one image; reuse ImageFolder's own mapping.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # ImageFolder assigns indices 0..N-1 in sorted order of the (wnid)
        # folder names, so enumerate() reproduces exactly those indices.
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    def download(self):
        """Download and prepare the meta file and the archive for
        ``self.split``, skipping every step whose result already exists."""
        if not self._check_meta_file_integrity():
            tmpdir = os.path.join(self.root, 'tmp')

            archive_dict = ARCHIVE_DICT['devkit']
            download_and_extract_tar(archive_dict['url'], self.root,
                                     extract_root=tmpdir,
                                     md5=archive_dict['md5'])
            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
            meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
            self._save_meta_file(*meta)

            shutil.rmtree(tmpdir)

        if not os.path.isdir(self.split_folder):
            archive_dict = ARCHIVE_DICT[self.split]
            download_and_extract_tar(archive_dict['url'], self.root,
                                     extract_root=self.split_folder,
                                     md5=archive_dict['md5'])

            if self.split == 'train':
                prepare_train_folder(self.split_folder)
            elif self.split == 'val':
                val_wnids = self._load_meta_file()[1]
                prepare_val_folder(self.split_folder, val_wnids)
        else:
            msg = ("You set download=True, but a folder '{}' already exist in "
                   "the root directory. If you want to re-download or re-extract the "
                   "archive, delete the folder.")
            print(msg.format(self.split))

    @property
    def meta_file(self):
        # Path of the serialized (wnid_to_classes, val_wnids) tuple.
        return os.path.join(self.root, META_DICT['filename'])

    def _check_meta_file_integrity(self):
        return check_integrity(self.meta_file, META_DICT['md5'])

    def _load_meta_file(self):
        """Return the (wnid_to_classes, val_wnids) tuple stored in the meta
        file, or raise ``RuntimeError`` if it is missing or corrupted."""
        if self._check_meta_file_integrity():
            return torch.load(self.meta_file)
        else:
            raise RuntimeError("Meta file not found or corrupted.",
                               "You can use download=True to create it.")

    def _save_meta_file(self, wnid_to_class, val_wnids):
        torch.save((wnid_to_class, val_wnids), self.meta_file)

    def _verify_split(self, split):
        """Return ``split`` unchanged, raising ``ValueError`` if invalid."""
        if split not in self.valid_splits:
            msg = "Unknown split {} .".format(split)
            # Bug fix: the original used "{{}}", which printed a literal "{}"
            # instead of interpolating the list of valid splits.
            msg += "Valid splits are {}.".format(", ".join(self.valid_splits))
            raise ValueError(msg)
        return split

    @property
    def valid_splits(self):
        return 'train', 'val'

    @property
    def split_folder(self):
        return os.path.join(self.root, self.split)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += ["Split: {}".format(self.split)]
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * 4 + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        # Align continuation lines of a multi-line transform repr under `head`.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
154+
155+
156+
def extract_tar(src, dest=None, gzip=None, delete=False):
    """Extract a (possibly gzipped) tar archive.

    Args:
        src (string): Path of the archive.
        dest (string, optional): Directory to extract into. Defaults to the
            directory containing ``src``.
        gzip (bool, optional): Whether the archive is gzip-compressed.
            Inferred from a ``.gz`` suffix when omitted.
        delete (bool, optional): If true, remove ``src`` after extraction.
    """
    import tarfile

    if dest is None:
        dest = os.path.dirname(src)
    if gzip is None:
        gzip = src.lower().endswith('.gz')

    with tarfile.open(src, 'r:gz' if gzip else 'r') as archive:
        archive.extractall(path=dest)

    if delete:
        os.remove(src)
170+
171+
172+
def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
                             md5=None, **kwargs):
    """Download a tar archive (unless a copy with a matching md5 is already
    present) and extract it.

    Args:
        url (string): URL to download the archive from.
        download_root (string): Directory the archive is downloaded into.
        extract_root (string, optional): Directory the archive is extracted
            into. Defaults to ``download_root``.
        filename (string, optional): Name for the downloaded file. Defaults
            to the basename of ``url``.
        md5 (string, optional): Expected md5 checksum of the archive.
        **kwargs: Forwarded to :func:`extract_tar`.
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        # Bug fix: the original self-assigned ``extract_root = extract_root``,
        # leaving it None; fall back to the download directory explicitly
        # (extract_tar's dirname(src) default happened to mask this).
        extract_root = download_root
    if filename is None:
        filename = os.path.basename(url)

    if not check_integrity(os.path.join(download_root, filename), md5):
        download_url(url, download_root, filename=filename, md5=md5)

    extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
184+
185+
186+
def parse_devkit(root):
    """Parse the extracted devkit folder.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)`` where ``wnid_to_classes`` maps
        each WordNet ID to its tuple of class names and ``val_wnids`` lists
        the ground-truth wnid of every validation image, in order.
    """
    idx_to_wnid, wnid_to_classes = parse_meta(root)
    val_wnids = [idx_to_wnid[idx] for idx in parse_val_groundtruth(root)]
    return wnid_to_classes, val_wnids
191+
192+
193+
def parse_meta(devkit_root, path='data', filename='meta.mat'):
    """Parse the devkit's ``meta.mat``.

    Returns:
        tuple: ``(idx_to_wnid, wnid_to_classes)`` mapping the ILSVRC class
        index to the WordNet ID, and the WordNet ID to its tuple of class
        names.
    """
    import scipy.io as sio

    metafile = os.path.join(devkit_root, path, filename)
    synsets = sio.loadmat(metafile, squeeze_me=True)['synsets']
    # Keep only the leaf synsets (num_children == 0, field index 4): these
    # are the actual classification classes.
    leaves = [entry for entry in synsets if entry[4] == 0]
    idcs, wnids, classes = zip(*[(entry[0], entry[1], entry[2])
                                 for entry in leaves])
    # The 'words' field is a comma-separated list of synonymous names.
    classes = [tuple(names.split(', ')) for names in classes]
    idx_to_wnid = dict(zip(idcs, wnids))
    wnid_to_classes = dict(zip(wnids, classes))
    return idx_to_wnid, wnid_to_classes
206+
207+
208+
def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    """Return the ground-truth class index of every validation image, in the
    order the lines appear in the devkit's ground-truth file."""
    groundtruth_file = os.path.join(devkit_root, path, filename)
    with open(groundtruth_file, 'r') as fh:
        return [int(line) for line in fh.readlines()]
213+
214+
215+
def prepare_train_folder(folder):
    """Extract each per-class archive inside ``folder`` into a subfolder
    named after the archive (i.e. the wnid), deleting the archive."""
    for name in os.listdir(folder):
        archive = os.path.join(folder, name)
        extract_tar(archive, os.path.splitext(archive)[0], delete=True)
218+
219+
220+
def prepare_val_folder(folder, wnids):
    """Sort the flat validation images in ``folder`` into one subfolder per
    wnid.

    ``wnids`` holds the ground-truth wnid of each validation image, in the
    lexicographic order of the image filenames.
    """
    images = sorted(os.path.join(folder, entry) for entry in os.listdir(folder))

    for wnid in set(wnids):
        os.mkdir(os.path.join(folder, wnid))

    for wnid, image in zip(wnids, images):
        shutil.move(image, os.path.join(folder, wnid, os.path.basename(image)))
228+
229+
230+
def _splitexts(root):
231+
exts = []
232+
ext = '.'
233+
while ext:
234+
root, ext = os.path.splitext(root)
235+
exts.append(ext)
236+
return root, ''.join(reversed(exts))

0 commit comments

Comments
 (0)