Skip to content

Commit 616492e

Browse files
author
Philip Meier
committed
removed support for other years than 2012
1 parent d6983b3 commit 616492e

File tree

2 files changed

+25
-47
lines changed

2 files changed

+25
-47
lines changed

torchvision/datasets/imagenet.py

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,34 @@
77

88
if sys.version_info[0] == 2:
99
# FIXME: I don't know if this is good pratice / robust
10-
FileExistsError = OSError
10+
FileNotFoundError = OSError
1111

1212
ARCHIVE_DICT = {
13-
('2012', 'train'): {
13+
'train': {
1414
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
1515
'md5': '1d675b47d978889d74fa0da5fadfb00e',
1616
},
17-
('2012', 'val'): {
17+
'val': {
1818
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
1919
'md5': '29b22e2961454d5413ddabcf34fc5622',
2020
},
21-
('2012', 'devkit'): {
21+
'devkit': {
2222
'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
2323
'md5': 'fa75699e90414af021442c21a62c3abf',
2424
}
2525
}
2626

2727
META_DICT = {
28-
'2012': '5c2648af14b2ff44540504b860a81a79',
28+
'filename': 'meta.bin',
29+
'md5': '5c2648af14b2ff44540504b860a81a79',
2930
}
3031

31-
META_FILE = 'meta.bin'
32-
3332

3433
class ImageNet(ImageFolder):
35-
"""`ImageNet <http://image-net.org/>`_ Classification Dataset.
34+
"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
3635
3736
Args:
3837
root (string): Root directory of the ImageNet Dataset.
39-
year (string, optional): The dataset year, supports years 2012 to 2012.
4038
split (string, optional): The dataset split, supports ``train``, or ``val``.
4139
download (bool, optional): If true, downloads the dataset from the internet and
4240
puts it in root directory. If dataset is already downloaded, it is not
@@ -56,11 +54,10 @@ class ImageNet(ImageFolder):
5654
targets (list): The class_index value for each image in the dataset
5755
"""
5856

59-
def __init__(self, root, split='train', year='2012', download=False, **kwargs):
57+
def __init__(self, root, split='train', download=False, **kwargs):
6058

6159
root = self.root = os.path.expanduser(root)
6260
self.split = self._verify_split(split)
63-
self.year = self._verify_year(year)
6461

6562
if download:
6663
self.download()
@@ -72,13 +69,13 @@ def __init__(self, root, split='train', year='2012', download=False, **kwargs):
7269
self.class_to_idx = class_to_idx
7370

7471
def download(self):
75-
self._prepare_tree()
72+
self._empty_split_folder()
7673

77-
meta_file = os.path.join(self.year_folder, META_FILE)
78-
if not check_integrity(meta_file, META_DICT[self.year]):
74+
meta_file = os.path.join(self.root, META_DICT['filename'])
75+
if not check_integrity(meta_file, META_DICT['md5']):
7976
tmpdir = os.path.join(self.root, 'tmp')
8077

81-
archive_dict = ARCHIVE_DICT[(self.year, 'devkit')]
78+
archive_dict = ARCHIVE_DICT['devkit']
8279
download_and_extract_tar(archive_dict['url'], self.root,
8380
extract_root=tmpdir,
8481
md5=archive_dict['md5'])
@@ -88,7 +85,7 @@ def download(self):
8885

8986
shutil.rmtree(tmpdir)
9087

91-
archive_dict = ARCHIVE_DICT[(self.year, self.split)]
88+
archive_dict = ARCHIVE_DICT[self.split]
9289
download_and_extract_tar(archive_dict['url'], self.root,
9390
extract_root=self.split_folder,
9491
md5=archive_dict['md5'])
@@ -101,13 +98,14 @@ def download(self):
10198

10299
def _load_meta(self):
103100
# TODO: verify meta file
104-
return torch.load(os.path.join(self.year_folder, META_FILE))[0]
101+
return torch.load(os.path.join(self.root, META_DICT['filename']))[0]
105102

106-
def _prepare_tree(self):
103+
def _empty_split_folder(self):
107104
try:
108-
os.makedirs(self.split_folder)
109-
except FileExistsError:
110105
shutil.rmtree(self.split_folder)
106+
except FileNotFoundError:
107+
pass
108+
os.makedirs(self.split_folder)
111109

112110
def _verify_split(self, split):
113111
if split not in self.valid_splits:
@@ -120,36 +118,16 @@ def _verify_split(self, split):
120118
def valid_splits(self):
121119
return 'train', 'val'
122120

123-
def _verify_year(self, year):
124-
if year not in self.valid_years:
125-
msg = "Unknown year {} .".format(year)
126-
msg += "Valid years are {{}}.".format(", ".join(self.valid_years))
127-
raise ValueError(msg)
128-
return year
129-
130-
@property
131-
def valid_years(self):
132-
return '2012',
133-
134-
@property
135-
def base_folder(self):
136-
return os.path.join(self.root, 'ILSVRC')
137-
138-
@property
139-
def year_folder(self):
140-
return os.path.join(self.base_folder, self.year)
141-
142121
@property
143122
def split_folder(self):
144-
return os.path.join(self.year_folder, self.split)
123+
return os.path.join(self.root, self.split)
145124

146125
def __repr__(self):
147126
head = "Dataset " + self.__class__.__name__
148127
body = ["Number of datapoints: {}".format(self.__len__())]
149128
if self.root is not None:
150129
body.append("Root location: {}".format(self.root))
151-
body += ["Year: {}".format(self.year),
152-
"Split: {}".format(self.split)]
130+
body += ["Split: {}".format(self.split)]
153131
if hasattr(self, 'transform') and self.transform is not None:
154132
body += self._format_transform_repr(self.transform,
155133
"Transforms: ")
@@ -196,7 +174,6 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
196174

197175

198176
def parse_devkit(root):
199-
# FIXME: generalize this for all years
200177
meta = parse_meta(root)
201178
val_idcs = parse_val_groundtruth(root)
202179

@@ -208,7 +185,6 @@ def parse_devkit(root):
208185

209186

210187
def parse_meta(devkit_root, path='data', filename='meta.mat'):
211-
# FIXME: generalize this for all years
212188
import scipy.io as sio
213189

214190
metafile = os.path.join(devkit_root, path, filename)
@@ -224,9 +200,8 @@ def parse_meta(devkit_root, path='data', filename='meta.mat'):
224200

225201
def parse_val_groundtruth(devkit_root, path='data',
226202
filename='ILSVRC2012_validation_ground_truth.txt'):
227-
# FIXME: generalize this for all years
228-
with open(os.path.join(devkit_root, path, filename), 'r') as fh:
229-
val_idcs = fh.readlines()
203+
with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
204+
val_idcs = txtfh.readlines()
230205
return [int(val_idx) for val_idx in val_idcs]
231206

232207

torchvision/tmp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from torchvision.datasets import ImageNet
2+
3+
a = ImageNet('~/Downloads/ImageNet', download=True)

0 commit comments

Comments
 (0)