
EMNIST dataset + speedup *MNIST preprocessing #334


Merged: 2 commits, Dec 6, 2017
5 changes: 5 additions & 0 deletions docs/source/datasets.rst
@@ -35,6 +35,11 @@ Fashion-MNIST

 .. autoclass:: FashionMNIST
 
+EMNIST
+~~~~~~
+
+.. autoclass:: EMNIST
+
 COCO
 ~~~~

4 changes: 2 additions & 2 deletions torchvision/datasets/__init__.py
@@ -3,7 +3,7 @@
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .stl10 import STL10
-from .mnist import MNIST, FashionMNIST
+from .mnist import MNIST, EMNIST, FashionMNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
 from .fakedata import FakeData
@@ -12,5 +12,5 @@
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100', 'FashionMNIST',
+           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
            'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
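
With EMNIST exported from torchvision.datasets, it can be used like the other datasets. A minimal usage sketch (the './data' root is an arbitrary example path; 'split' must be one of the six values validated by the class added below):

    from torchvision import transforms
    from torchvision.datasets import EMNIST

    # first use downloads gzip.zip and processes all six splits into
    # processed/training_<split>.pt and processed/test_<split>.pt
    emnist = EMNIST('./data', split='letters', train=True, download=True,
                    transform=transforms.ToTensor())

    img, target = emnist[0]  # ToTensor() yields a 1x28x28 float tensor
    print(img.size(), target)
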
124 changes: 103 additions & 21 deletions torchvision/datasets/mnist.py
@@ -4,6 +4,7 @@
 import os
 import os.path
 import errno
+import numpy as np
 import torch
 import codecs
 
@@ -163,24 +164,115 @@ class FashionMNIST(MNIST):
 ]
 
 
-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
+class EMNIST(MNIST):
+    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``processed/training.pt``
+            and ``processed/test.pt`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If the dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
+    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+
+    def __init__(self, root, split, **kwargs):
+        if split not in self.splits:
+            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+                split, ', '.join(self.splits),
+            ))
+        self.split = split
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super(EMNIST, self).__init__(root, **kwargs)
+
+    def _training_file(self, split):
+        return 'training_{}.pt'.format(split)
+
+    def _test_file(self, split):
+        return 'test_{}.pt'.format(split)
+
+    def download(self):
+        """Download the EMNIST data if it doesn't exist in processed_folder already."""
+        from six.moves import urllib
+        import gzip
+        import shutil
+        import zipfile
+
-def parse_byte(b):
-    if isinstance(b, str):
-        return ord(b)
-    return b
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        print('Downloading ' + self.url)
+        data = urllib.request.urlopen(self.url)
+        filename = self.url.rpartition('/')[2]
+        raw_folder = os.path.join(self.root, self.raw_folder)
+        file_path = os.path.join(raw_folder, filename)
+        with open(file_path, 'wb') as f:
+            f.write(data.read())
+
+        print('Extracting zip archive')
+        with zipfile.ZipFile(file_path) as zip_f:
+            zip_f.extractall(raw_folder)
+        os.unlink(file_path)
+        gzip_folder = os.path.join(raw_folder, 'gzip')
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith('.gz'):
+                print('Extracting ' + gzip_file)
+                with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
+                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
+                    out_f.write(zip_f.read())
+        shutil.rmtree(gzip_folder)
+
+        # process and save as torch files
+        for split in self.splits:
+            print('Processing ' + split)
+            training_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
+            )
+            test_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
+            )
+            with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
+                torch.save(training_set, f)
+            with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
+                torch.save(test_set, f)
+
+        print('Done!')
+
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
 
 
 def read_label_file(path):
     with open(path, 'rb') as f:
         data = f.read()
         assert get_int(data[:4]) == 2049
         length = get_int(data[4:8])
-        labels = [parse_byte(b) for b in data[8:]]
-        assert len(labels) == length
-        return torch.LongTensor(labels)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
+        return torch.from_numpy(parsed).view(length).long()


 def read_image_file(path):
@@ -191,15 +283,5 @@ def read_image_file(path):
         num_rows = get_int(data[8:12])
         num_cols = get_int(data[12:16])
         images = []
-        idx = 16
-        for l in range(length):
-            img = []
-            images.append(img)
-            for r in range(num_rows):
-                row = []
-                img.append(row)
-                for c in range(num_cols):
-                    row.append(parse_byte(data[idx]))
-                    idx += 1
-        assert len(images) == length
-        return torch.ByteTensor(images).view(-1, 28, 28)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
+        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
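
For reference, the magic numbers and offsets above come from the IDX format shared by all of the *MNIST files: a big-endian header (magic number, item count, and, for image files, row and column counts) followed by raw unsigned bytes. A small sketch of why the np.frombuffer rewrite is both a speedup and behavior-preserving, using a synthetic label buffer in place of a downloaded file (the header values below are made up for the example):

    import codecs
    import struct

    import numpy as np
    import torch

    def get_int(b):
        # four big-endian bytes -> int, as in the diff above
        return int(codecs.encode(b, 'hex'), 16)

    # synthetic label file: magic 2049, 3 items, then one byte per label
    data = struct.pack('>ii', 2049, 3) + bytes(bytearray([7, 0, 2]))

    assert get_int(data[:4]) == 2049
    length = get_int(data[4:8])

    # old path: one Python-level iteration per byte (per pixel for images)
    slow = torch.LongTensor([b for b in bytearray(data[8:])])

    # new path: a single zero-copy view over the same buffer
    fast = torch.from_numpy(np.frombuffer(data, dtype=np.uint8, offset=8)).view(length).long()

    assert torch.equal(slow, fast)

The image reader benefits the most: the removed triple loop executed one Python iteration per pixel, which for the larger EMNIST splits means hundreds of millions of iterations, while np.frombuffer hands the whole byte range to torch in one call.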