Skip to content

Commit 33b252b

Browse files
author
Martin Raison
committed
EMNIST dataset + speedup *MNIST preprocessing
1 parent c31c3d7 commit 33b252b

File tree

3 files changed

+109
-23
lines changed

3 files changed

+109
-23
lines changed

docs/source/datasets.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ Fashion-MNIST
3535

3636
.. autoclass:: FashionMNIST
3737

38+
EMNIST
39+
~~~~~~
40+
41+
.. autoclass:: EMNIST
42+
3843
COCO
3944
~~~~
4045

torchvision/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .coco import CocoCaptions, CocoDetection
44
from .cifar import CIFAR10, CIFAR100
55
from .stl10 import STL10
6-
from .mnist import MNIST, FashionMNIST
6+
from .mnist import MNIST, EMNIST, FashionMNIST
77
from .svhn import SVHN
88
from .phototour import PhotoTour
99
from .fakedata import FakeData
@@ -12,5 +12,5 @@
1212
__all__ = ('LSUN', 'LSUNClass',
1313
'ImageFolder', 'FakeData',
1414
'CocoCaptions', 'CocoDetection',
15-
'CIFAR10', 'CIFAR100', 'FashionMNIST',
15+
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
1616
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')

torchvision/datasets/mnist.py

Lines changed: 102 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import os.path
66
import errno
7+
import numpy as np
78
import torch
89
import codecs
910

@@ -162,25 +163,115 @@ class FashionMNIST(MNIST):
162163
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
163164
]
164165

166+
class EMNIST(MNIST):
    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If ``split`` is not one of the names in ``splits``.
    """
    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            raise RuntimeError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.split = split
        # The split-specific file names are assigned before delegating to the
        # parent constructor, which receives the remaining keyword arguments.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    def _training_file(self, split):
        """Return the processed training file name for ``split``."""
        return 'training_{}.pt'.format(split)

    def _test_file(self, split):
        """Return the processed test file name for ``split``."""
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already."""
        from contextlib import closing
        from six.moves import urllib
        import gzip
        import shutil
        import zipfile

        if self._check_exists():
            return

        raw_folder = os.path.join(self.root, self.raw_folder)
        processed_folder = os.path.join(self.root, self.processed_folder)

        # Create each directory in its own try block: with a shared block, an
        # already-existing raw folder raises EEXIST and the processed folder
        # would never be created.
        for folder in (raw_folder, processed_folder):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        # download files
        print('Downloading ' + self.url)
        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(raw_folder, filename)
        # Stream the archive to disk instead of holding it all in memory, and
        # make sure the connection is closed afterwards.
        with closing(urllib.request.urlopen(self.url)) as response, \
                open(file_path, 'wb') as f:
            shutil.copyfileobj(response, f)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(raw_folder)
        os.unlink(file_path)

        gzip_folder = os.path.join(raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                print('Extracting ' + gzip_file)
                # Strip only the trailing '.gz'; str.replace would also remove
                # any earlier occurrence inside the file name.
                out_path = os.path.join(raw_folder, gzip_file[:-len('.gz')])
                with gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f, \
                        open(out_path, 'wb') as out_f:
                    shutil.copyfileobj(zip_f, out_f)
        shutil.rmtree(gzip_folder)

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
            )
            test_set = (
                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
            )
            with open(os.path.join(processed_folder, self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(processed_folder, self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)

        print('Done!')
262+
263+
264+
def get_int(b):
    """Interpret the byte string ``b`` as an unsigned big-endian integer.

    Args:
        b (bytes): Raw bytes, e.g. a 4-byte header field.

    Returns:
        int: The integer value of the hexadecimal encoding of ``b``.
    """
    # Hex-encode and parse base 16: the hex digits appear in byte order, so
    # the first byte is the most significant.
    return int(codecs.encode(b, 'hex'), 16)
174266

175267

176268
def read_label_file(path):
    """Read a label file and return its labels as a 1-D ``LongTensor``.

    The first 4 bytes must decode to 2049 and the next 4 bytes give the
    number of labels; every remaining byte is one label.

    Args:
        path (string): Path to the uncompressed label file.

    Returns:
        torch.LongTensor: Tensor of shape ``(length,)``.

    Raises:
        AssertionError: If the leading 4-byte value is not 2049.
        RuntimeError: If the number of label bytes does not match the header.
    """
    with open(path, 'rb') as f:
        data = f.read()
    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])
    # Parse all label bytes in one call rather than one Python object per byte.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    # Check the count explicitly so a truncated or padded file gives a clear
    # message instead of an opaque shape error from view().
    if parsed.size != length:
        raise RuntimeError('Label file {} declares {} labels but contains {}'.format(
            path, length, parsed.size))
    # .long() copies, so the result does not alias the read-only byte buffer.
    return torch.from_numpy(parsed).view(length).long()
184275

185276

186277
def read_image_file(path):
@@ -191,15 +282,5 @@ def read_image_file(path):
191282
num_rows = get_int(data[8:12])
192283
num_cols = get_int(data[12:16])
193284
images = []
194-
idx = 16
195-
for l in range(length):
196-
img = []
197-
images.append(img)
198-
for r in range(num_rows):
199-
row = []
200-
img.append(row)
201-
for c in range(num_cols):
202-
row.append(parse_byte(data[idx]))
203-
idx += 1
204-
assert len(images) == length
205-
return torch.ByteTensor(images).view(-1, 28, 28)
285+
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
286+
return torch.from_numpy(parsed).view(length, num_rows, num_cols)

0 commit comments

Comments
 (0)