Skip to content

Commit b63a576

Browse files
author
Martin Raison
committed
EMNIST dataset + speedup *MNIST preprocessing
1 parent c31c3d7 commit b63a576

File tree

3 files changed

+110
-23
lines changed

3 files changed

+110
-23
lines changed

docs/source/datasets.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ Fashion-MNIST
3535

3636
.. autoclass:: FashionMNIST
3737

38+
EMNIST
39+
~~~~~~
40+
41+
.. autoclass:: EMNIST
42+
3843
COCO
3944
~~~~
4045

torchvision/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .coco import CocoCaptions, CocoDetection
44
from .cifar import CIFAR10, CIFAR100
55
from .stl10 import STL10
6-
from .mnist import MNIST, FashionMNIST
6+
from .mnist import MNIST, EMNIST, FashionMNIST
77
from .svhn import SVHN
88
from .phototour import PhotoTour
99
from .fakedata import FakeData
@@ -12,5 +12,5 @@
1212
__all__ = ('LSUN', 'LSUNClass',
1313
'ImageFolder', 'FakeData',
1414
'CocoCaptions', 'CocoDetection',
15-
'CIFAR10', 'CIFAR100', 'FashionMNIST',
15+
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
1616
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')

torchvision/datasets/mnist.py

Lines changed: 103 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import os.path
66
import errno
7+
import numpy as np
78
import torch
89
import codecs
910

@@ -163,24 +164,115 @@ class FashionMNIST(MNIST):
163164
]
164165

165166

166-
def get_int(b):
167-
return int(codecs.encode(b, 'hex'), 16)
167+
class EMNIST(MNIST):
    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Single zip archive that contains the gzipped IDX files for all six splits.
    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        # Validate the requested split before the parent constructor runs
        # (the parent may trigger download/loading based on these filenames).
        if split not in self.splits:
            raise RuntimeError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.split = split
        # Override the per-split processed filenames read by the MNIST base class.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    def _training_file(self, split):
        # Processed training tensor file for one split, e.g. 'training_letters.pt'.
        return 'training_{}.pt'.format(split)

    def _test_file(self, split):
        # Processed test tensor file for one split, e.g. 'test_letters.pt'.
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip
        import shutil
        import zipfile

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            # The directories already existing is fine; anything else is fatal.
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        print('Downloading ' + self.url)
        data = urllib.request.urlopen(self.url)
        filename = self.url.rpartition('/')[2]
        raw_folder = os.path.join(self.root, self.raw_folder)
        file_path = os.path.join(raw_folder, filename)
        with open(file_path, 'wb') as f:
            # NOTE(review): reads the whole archive into memory before writing;
            # acceptable here given the archive size, but not streamed.
            f.write(data.read())

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(raw_folder)
        # Remove the downloaded zip once extracted.
        os.unlink(file_path)
        # The archive unpacks into a 'gzip' subfolder of per-split .gz IDX files;
        # decompress each one into raw_folder, then drop the intermediate folder.
        gzip_folder = os.path.join(raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                print('Extracting ' + gzip_file)
                with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
                    out_f.write(zip_f.read())
        shutil.rmtree(gzip_folder)

        # process and save as torch files
        # All six splits are processed up front, regardless of which one was
        # requested, so later instantiations need no further download.
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
            )
            test_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
            )
            with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)

        print('Done!')
263+
264+
265+
def get_int(b):
    """Interpret the byte string ``b`` as a big-endian unsigned integer."""
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
174267

175268

176269
def read_label_file(path):
    """Parse an IDX1 label file at ``path`` into a 1-D LongTensor."""
    with open(path, 'rb') as fh:
        raw = fh.read()
    # Magic number 2049 identifies an IDX label file.
    assert get_int(raw[:4]) == 2049
    count = get_int(raw[4:8])
    # Decode all labels in one vectorized pass; data starts after the 8-byte header.
    labels = np.frombuffer(raw, dtype=np.uint8, offset=8)
    return torch.from_numpy(labels).view(count).long()
184276

185277

186278
def read_image_file(path):
@@ -191,15 +283,5 @@ def read_image_file(path):
191283
num_rows = get_int(data[8:12])
192284
num_cols = get_int(data[12:16])
193285
images = []
194-
idx = 16
195-
for l in range(length):
196-
img = []
197-
images.append(img)
198-
for r in range(num_rows):
199-
row = []
200-
img.append(row)
201-
for c in range(num_cols):
202-
row.append(parse_byte(data[idx]))
203-
idx += 1
204-
assert len(images) == length
205-
return torch.ByteTensor(images).view(-1, 28, 28)
286+
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
287+
return torch.from_numpy(parsed).view(length, num_rows, num_cols)

0 commit comments

Comments
 (0)