 import os
 import os.path
 import errno
+import numpy as np
 import torch
 import codecs
@@ -163,24 +164,115 @@ class FashionMNIST(MNIST):
     ]


-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
+class EMNIST(MNIST):
+    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``processed/training.pt``
+            and ``processed/test.pt`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
+    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+
+    def __init__(self, root, split, **kwargs):
+        if split not in self.splits:
+            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+                split, ', '.join(self.splits),
+            ))
+        self.split = split
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super(EMNIST, self).__init__(root, **kwargs)

+    def _training_file(self, split):
+        return 'training_{}.pt'.format(split)
+
+    def _test_file(self, split):
+        return 'test_{}.pt'.format(split)
+
+    def download(self):
+        """Download the EMNIST data if it doesn't exist in processed_folder already."""
+        from six.moves import urllib
+        import gzip
+        import shutil
+        import zipfile

-def parse_byte(b):
-    if isinstance(b, str):
-        return ord(b)
-    return b
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        print('Downloading ' + self.url)
+        data = urllib.request.urlopen(self.url)
+        filename = self.url.rpartition('/')[2]
+        raw_folder = os.path.join(self.root, self.raw_folder)
+        file_path = os.path.join(raw_folder, filename)
+        with open(file_path, 'wb') as f:
+            f.write(data.read())
+
+        print('Extracting zip archive')
+        with zipfile.ZipFile(file_path) as zip_f:
+            zip_f.extractall(raw_folder)
+        os.unlink(file_path)
+        gzip_folder = os.path.join(raw_folder, 'gzip')
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith('.gz'):
+                print('Extracting ' + gzip_file)
+                with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
+                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
+                    out_f.write(zip_f.read())
+        shutil.rmtree(gzip_folder)
+
+        # process and save as torch files
+        for split in self.splits:
+            print('Processing ' + split)
+            training_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
+            )
+            test_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
+            )
+            with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
+                torch.save(training_set, f)
+            with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
+                torch.save(test_set, f)
+
+        print('Done!')
+
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)


 def read_label_file(path):
     with open(path, 'rb') as f:
         data = f.read()
         assert get_int(data[:4]) == 2049
         length = get_int(data[4:8])
-        labels = [parse_byte(b) for b in data[8:]]
-        assert len(labels) == length
-        return torch.LongTensor(labels)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
+        return torch.from_numpy(parsed).view(length).long()


 def read_image_file(path):
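
For reference, a minimal usage sketch of the EMNIST class added in the hunk above. It assumes the class is exposed as torchvision.datasets.EMNIST (as the other datasets in this module are); the root path and split are illustrative.

from torchvision import datasets, transforms

# Hypothetical example: load the "letters" split, downloading and
# processing the archives on first use (this can take a while).
emnist_train = datasets.EMNIST(root='./data', split='letters', train=True,
                               download=True, transform=transforms.ToTensor())

img, target = emnist_train[0]  # 1x28x28 float tensor in [0, 1], integer label
print(len(emnist_train), img.shape, target)

The split only selects which processed file is loaded; all keyword arguments other than split are forwarded unchanged to the MNIST base class.
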
@@ -191,15 +283,5 @@ def read_image_file(path):
         num_rows = get_int(data[8:12])
         num_cols = get_int(data[12:16])
         images = []
-        idx = 16
-        for l in range(length):
-            img = []
-            images.append(img)
-            for r in range(num_rows):
-                row = []
-                img.append(row)
-                for c in range(num_cols):
-                    row.append(parse_byte(data[idx]))
-                    idx += 1
-        assert len(images) == length
-        return torch.ByteTensor(images).view(-1, 28, 28)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
+        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
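
To illustrate the vectorized parsing that replaces the per-byte Python loops, here is a small self-contained sketch; the header bytes and the three labels below are made-up test data in the IDX1 layout that read_label_file expects (big-endian magic number, big-endian length, then one unsigned byte per label).

import codecs
import numpy as np
import torch

def get_int(b):
    # interpret 4 raw bytes as a big-endian 32-bit integer
    return int(codecs.encode(b, 'hex'), 16)

# Fake IDX1 buffer: magic number 2049, length 3, then the labels 7, 0, 4.
data = b'\x00\x00\x08\x01' + b'\x00\x00\x00\x03' + bytes([7, 0, 4])

assert get_int(data[:4]) == 2049
length = get_int(data[4:8])

# np.frombuffer views the label bytes without copying; torch.from_numpy
# shares that memory, and .long() produces the LongTensor expected downstream.
labels = torch.from_numpy(np.frombuffer(data, dtype=np.uint8, offset=8)).view(length).long()
print(labels)  # tensor([7, 0, 4])

The image reader follows the same pattern with offset=16 (magic, length, rows, cols) and a view of (length, num_rows, num_cols).
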