4
4
import os
5
5
import os .path
6
6
import errno
7
+ import numpy as np
7
8
import torch
8
9
import codecs
9
10
@@ -162,25 +163,115 @@ class FashionMNIST(MNIST):
162
163
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz' ,
163
164
]
164
165
166
+ class EMNIST (MNIST ):
167
+ """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
165
168
166
- def get_int (b ):
167
- return int (codecs .encode (b , 'hex' ), 16 )
169
+ Args:
170
+ root (string): Root directory of dataset where ``processed/training.pt``
171
+ and ``processed/test.pt`` exist.
172
+ split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
173
+ ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
174
+ which one to use.
175
+ train (bool, optional): If True, creates dataset from ``training.pt``,
176
+ otherwise from ``test.pt``.
177
+ download (bool, optional): If true, downloads the dataset from the internet and
178
+ puts it in root directory. If dataset is already downloaded, it is not
179
+ downloaded again.
180
+ transform (callable, optional): A function/transform that takes in an PIL image
181
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
182
+ target_transform (callable, optional): A function/transform that takes in the
183
+ target and transforms it.
184
+ """
185
+ url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
186
+ splits = ('byclass' , 'bymerge' , 'balanced' , 'letters' , 'digits' , 'mnist' )
187
+
188
+ def __init__ (self , root , split , ** kwargs ):
189
+ if split not in self .splits :
190
+ raise RuntimeError ('Split "{}" not found. Valid splits are: {}' .format (
191
+ split , ', ' .join (self .splits ),
192
+ ))
193
+ self .split = split
194
+ self .training_file = self ._training_file (split )
195
+ self .test_file = self ._test_file (split )
196
+ super (EMNIST , self ).__init__ (root , ** kwargs )
197
+
198
+ def _training_file (self , split ):
199
+ return 'training_{}.pt' .format (split )
200
+
201
+ def _test_file (self , split ):
202
+ return 'test_{}.pt' .format (split )
203
+
204
+ def download (self ):
205
+ """Download the EMNIST data if it doesn't exist in processed_folder already."""
206
+ from six .moves import urllib
207
+ import gzip
208
+ import shutil
209
+ import zipfile
168
210
211
+ if self ._check_exists ():
212
+ return
169
213
170
- def parse_byte (b ):
171
- if isinstance (b , str ):
172
- return ord (b )
173
- return b
214
+ # download files
215
+ try :
216
+ os .makedirs (os .path .join (self .root , self .raw_folder ))
217
+ os .makedirs (os .path .join (self .root , self .processed_folder ))
218
+ except OSError as e :
219
+ if e .errno == errno .EEXIST :
220
+ pass
221
+ else :
222
+ raise
223
+
224
+ print ('Downloading ' + self .url )
225
+ data = urllib .request .urlopen (self .url )
226
+ filename = self .url .rpartition ('/' )[2 ]
227
+ raw_folder = os .path .join (self .root , self .raw_folder )
228
+ file_path = os .path .join (raw_folder , filename )
229
+ with open (file_path , 'wb' ) as f :
230
+ f .write (data .read ())
231
+
232
+ print ('Extracting zip archive' )
233
+ with zipfile .ZipFile (file_path ) as zip_f :
234
+ zip_f .extractall (raw_folder )
235
+ os .unlink (file_path )
236
+ gzip_folder = os .path .join (raw_folder , 'gzip' )
237
+ for gzip_file in os .listdir (gzip_folder ):
238
+ if gzip_file .endswith ('.gz' ):
239
+ print ('Extracting ' + gzip_file )
240
+ with open (os .path .join (raw_folder , gzip_file .replace ('.gz' , '' )), 'wb' ) as out_f , \
241
+ gzip .GzipFile (os .path .join (gzip_folder , gzip_file )) as zip_f :
242
+ out_f .write (zip_f .read ())
243
+ shutil .rmtree (gzip_folder )
244
+
245
+ # process and save as torch files
246
+ for split in self .splits :
247
+ print ('Processing ' + split )
248
+ training_set = (
249
+ read_image_file (os .path .join (raw_folder , 'emnist-{}-train-images-idx3-ubyte' .format (split ))),
250
+ read_label_file (os .path .join (raw_folder , 'emnist-{}-train-labels-idx1-ubyte' .format (split )))
251
+ )
252
+ test_set = (
253
+ read_image_file (os .path .join (raw_folder , 'emnist-{}-test-images-idx3-ubyte' .format (split ))),
254
+ read_label_file (os .path .join (raw_folder , 'emnist-{}-test-labels-idx1-ubyte' .format (split )))
255
+ )
256
+ with open (os .path .join (self .root , self .processed_folder , self ._training_file (split )), 'wb' ) as f :
257
+ torch .save (training_set , f )
258
+ with open (os .path .join (self .root , self .processed_folder , self ._test_file (split )), 'wb' ) as f :
259
+ torch .save (test_set , f )
260
+
261
+ print ('Done!' )
262
+
263
+
264
+ def get_int (b ):
265
+ return int (codecs .encode (b , 'hex' ), 16 )
174
266
175
267
176
268
def read_label_file (path ):
177
269
with open (path , 'rb' ) as f :
178
270
data = f .read ()
179
271
assert get_int (data [:4 ]) == 2049
180
272
length = get_int (data [4 :8 ])
181
- labels = [parse_byte (b ) for b in data [8 :]]
182
- assert len (labels ) == length
183
- return torch .LongTensor (labels )
273
+ parsed = np .frombuffer (data , dtype = np .uint8 , offset = 8 )
274
+ return torch .from_numpy (parsed ).view (length ).long ()
184
275
185
276
186
277
def read_image_file (path ):
@@ -191,15 +282,5 @@ def read_image_file(path):
191
282
num_rows = get_int (data [8 :12 ])
192
283
num_cols = get_int (data [12 :16 ])
193
284
images = []
194
- idx = 16
195
- for l in range (length ):
196
- img = []
197
- images .append (img )
198
- for r in range (num_rows ):
199
- row = []
200
- img .append (row )
201
- for c in range (num_cols ):
202
- row .append (parse_byte (data [idx ]))
203
- idx += 1
204
- assert len (images ) == length
205
- return torch .ByteTensor (images ).view (- 1 , 28 , 28 )
285
+ parsed = np .frombuffer (data , dtype = np .uint8 , offset = 16 )
286
+ return torch .from_numpy (parsed ).view (length , num_rows , num_cols )
0 commit comments