7
7
8
8
if sys .version_info [0 ] == 2 :
9
9
# FIXME: I don't know if this is good pratice / robust
10
- FileExistsError = OSError
10
+ FileNotFoundError = OSError
11
11
12
12
ARCHIVE_DICT = {
13
- ( '2012' , ' train') : {
13
+ ' train' : {
14
14
'url' : 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar' ,
15
15
'md5' : '1d675b47d978889d74fa0da5fadfb00e' ,
16
16
},
17
- ( '2012' , ' val') : {
17
+ ' val' : {
18
18
'url' : 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar' ,
19
19
'md5' : '29b22e2961454d5413ddabcf34fc5622' ,
20
20
},
21
- ( '2012' , ' devkit') : {
21
+ ' devkit' : {
22
22
'url' : 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz' ,
23
23
'md5' : 'fa75699e90414af021442c21a62c3abf' ,
24
24
}
25
25
}
26
26
27
27
META_DICT = {
28
- '2012' : '5c2648af14b2ff44540504b860a81a79' ,
28
+ 'filename' : 'meta.bin' ,
29
+ 'md5' : '5c2648af14b2ff44540504b860a81a79' ,
29
30
}
30
31
31
- META_FILE = 'meta.bin'
32
-
33
32
34
33
class ImageNet (ImageFolder ):
35
- """`ImageNet <http://image-net.org/>`_ Classification Dataset.
34
+ """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
36
35
37
36
Args:
38
37
root (string): Root directory of the ImageNet Dataset.
39
- year (string, optional): The dataset year, supports years 2012 to 2012.
40
38
split (string, optional): The dataset split, supports ``train``, or ``val``.
41
39
download (bool, optional): If true, downloads the dataset from the internet and
42
40
puts it in root directory. If dataset is already downloaded, it is not
@@ -56,11 +54,10 @@ class ImageNet(ImageFolder):
56
54
targets (list): The class_index value for each image in the dataset
57
55
"""
58
56
59
- def __init__ (self , root , split = 'train' , year = '2012' , download = False , ** kwargs ):
57
+ def __init__ (self , root , split = 'train' , download = False , ** kwargs ):
60
58
61
59
root = self .root = os .path .expanduser (root )
62
60
self .split = self ._verify_split (split )
63
- self .year = self ._verify_year (year )
64
61
65
62
if download :
66
63
self .download ()
@@ -72,13 +69,13 @@ def __init__(self, root, split='train', year='2012', download=False, **kwargs):
72
69
self .class_to_idx = class_to_idx
73
70
74
71
def download (self ):
75
- self ._prepare_tree ()
72
+ self ._empty_split_folder ()
76
73
77
- meta_file = os .path .join (self .year_folder , META_FILE )
78
- if not check_integrity (meta_file , META_DICT [self . year ]):
74
+ meta_file = os .path .join (self .root , META_DICT [ 'filename' ] )
75
+ if not check_integrity (meta_file , META_DICT ['md5' ]):
79
76
tmpdir = os .path .join (self .root , 'tmp' )
80
77
81
- archive_dict = ARCHIVE_DICT [( self . year , 'devkit' ) ]
78
+ archive_dict = ARCHIVE_DICT ['devkit' ]
82
79
download_and_extract_tar (archive_dict ['url' ], self .root ,
83
80
extract_root = tmpdir ,
84
81
md5 = archive_dict ['md5' ])
@@ -88,7 +85,7 @@ def download(self):
88
85
89
86
shutil .rmtree (tmpdir )
90
87
91
- archive_dict = ARCHIVE_DICT [( self .year , self . split ) ]
88
+ archive_dict = ARCHIVE_DICT [self .split ]
92
89
download_and_extract_tar (archive_dict ['url' ], self .root ,
93
90
extract_root = self .split_folder ,
94
91
md5 = archive_dict ['md5' ])
@@ -101,13 +98,14 @@ def download(self):
101
98
102
99
def _load_meta (self ):
103
100
# TODO: verify meta file
104
- return torch .load (os .path .join (self .year_folder , META_FILE ))[0 ]
101
+ return torch .load (os .path .join (self .root , META_DICT [ 'filename' ] ))[0 ]
105
102
106
- def _prepare_tree (self ):
103
+ def _empty_split_folder (self ):
107
104
try :
108
- os .makedirs (self .split_folder )
109
- except FileExistsError :
110
105
shutil .rmtree (self .split_folder )
106
+ except FileNotFoundError :
107
+ pass
108
+ os .makedirs (self .split_folder )
111
109
112
110
def _verify_split (self , split ):
113
111
if split not in self .valid_splits :
@@ -120,36 +118,16 @@ def _verify_split(self, split):
120
118
def valid_splits (self ):
121
119
return 'train' , 'val'
122
120
123
- def _verify_year (self , year ):
124
- if year not in self .valid_years :
125
- msg = "Unknown year {} ." .format (year )
126
- msg += "Valid years are {{}}." .format (", " .join (self .valid_years ))
127
- raise ValueError (msg )
128
- return year
129
-
130
- @property
131
- def valid_years (self ):
132
- return '2012' ,
133
-
134
- @property
135
- def base_folder (self ):
136
- return os .path .join (self .root , 'ILSVRC' )
137
-
138
- @property
139
- def year_folder (self ):
140
- return os .path .join (self .base_folder , self .year )
141
-
142
121
@property
143
122
def split_folder (self ):
144
- return os .path .join (self .year_folder , self .split )
123
+ return os .path .join (self .root , self .split )
145
124
146
125
def __repr__ (self ):
147
126
head = "Dataset " + self .__class__ .__name__
148
127
body = ["Number of datapoints: {}" .format (self .__len__ ())]
149
128
if self .root is not None :
150
129
body .append ("Root location: {}" .format (self .root ))
151
- body += ["Year: {}" .format (self .year ),
152
- "Split: {}" .format (self .split )]
130
+ body += ["Split: {}" .format (self .split )]
153
131
if hasattr (self , 'transform' ) and self .transform is not None :
154
132
body += self ._format_transform_repr (self .transform ,
155
133
"Transforms: " )
@@ -196,7 +174,6 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
196
174
197
175
198
176
def parse_devkit (root ):
199
- # FIXME: generalize this for all years
200
177
meta = parse_meta (root )
201
178
val_idcs = parse_val_groundtruth (root )
202
179
@@ -208,7 +185,6 @@ def parse_devkit(root):
208
185
209
186
210
187
def parse_meta (devkit_root , path = 'data' , filename = 'meta.mat' ):
211
- # FIXME: generalize this for all years
212
188
import scipy .io as sio
213
189
214
190
metafile = os .path .join (devkit_root , path , filename )
@@ -224,9 +200,8 @@ def parse_meta(devkit_root, path='data', filename='meta.mat'):
224
200
225
201
def parse_val_groundtruth (devkit_root , path = 'data' ,
226
202
filename = 'ILSVRC2012_validation_ground_truth.txt' ):
227
- # FIXME: generalize this for all years
228
- with open (os .path .join (devkit_root , path , filename ), 'r' ) as fh :
229
- val_idcs = fh .readlines ()
203
+ with open (os .path .join (devkit_root , path , filename ), 'r' ) as txtfh :
204
+ val_idcs = txtfh .readlines ()
230
205
return [int (val_idx ) for val_idx in val_idcs ]
231
206
232
207
0 commit comments