diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index a4a7fe57ad6..46976051f31 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -35,6 +35,11 @@ Fashion-MNIST
 ~~~~~~~~~~~~~
 
 .. autoclass:: FashionMNIST
 
+EMNIST
+~~~~~~
+
+.. autoclass:: EMNIST
+
 COCO
 ~~~~
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 9fab55190cc..5a478aa0eff 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -3,7 +3,7 @@
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .stl10 import STL10
-from .mnist import MNIST, FashionMNIST
+from .mnist import MNIST, EMNIST, FashionMNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
 from .fakedata import FakeData
@@ -12,5 +12,5 @@
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100', 'FashionMNIST',
+           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
            'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 44ade5d4378..8ff6b574eda 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -4,6 +4,7 @@
 import os
 import os.path
 import errno
+import numpy as np
 import torch
 import codecs
 
@@ -163,14 +164,106 @@ class FashionMNIST(MNIST):
     ]
 
 
-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
+class EMNIST(MNIST):
+    """`EMNIST <https://www.nist.gov/itl/iacs/products-data/emnist-dataset>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``processed/training_<split>.pt``
+            and ``processed/test_<split>.pt`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
+    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+
+    def __init__(self, root, split, **kwargs):
+        if split not in self.splits:
+            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+                split, ', '.join(self.splits),
+            ))
+        self.split = split
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super(EMNIST, self).__init__(root, **kwargs)
 
+    def _training_file(self, split):
+        return 'training_{}.pt'.format(split)
+
+    def _test_file(self, split):
+        return 'test_{}.pt'.format(split)
+
+    def download(self):
+        """Download the EMNIST data if it doesn't exist in processed_folder already."""
+        from six.moves import urllib
+        import gzip
+        import shutil
+        import zipfile
 
-def parse_byte(b):
-    if isinstance(b, str):
-        return ord(b)
-    return b
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        print('Downloading ' + self.url)
+        data = urllib.request.urlopen(self.url)
+        filename = self.url.rpartition('/')[2]
+        raw_folder = os.path.join(self.root, self.raw_folder)
+        file_path = os.path.join(raw_folder, filename)
+        with open(file_path, 'wb') as f:
+            f.write(data.read())
+
+        print('Extracting zip archive')
+        with zipfile.ZipFile(file_path) as zip_f:
+            zip_f.extractall(raw_folder)
+        os.unlink(file_path)
+        gzip_folder = os.path.join(raw_folder, 'gzip')
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith('.gz'):
+                print('Extracting ' + gzip_file)
+                with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
+                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
+                    out_f.write(zip_f.read())
+        shutil.rmtree(gzip_folder)
+
+        # process and save as torch files
+        for split in self.splits:
+            print('Processing ' + split)
+            training_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
+            )
+            test_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
+            )
+            with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
+                torch.save(training_set, f)
+            with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
+                torch.save(test_set, f)
+
+        print('Done!')
+
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
 
 
 def read_label_file(path):
@@ -178,9 +271,8 @@ def read_label_file(path):
         data = f.read()
         assert get_int(data[:4]) == 2049
         length = get_int(data[4:8])
-        labels = [parse_byte(b) for b in data[8:]]
-        assert len(labels) == length
-        return torch.LongTensor(labels)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
+        return torch.from_numpy(parsed).view(length).long()
 
 
 def read_image_file(path):
@@ -191,15 +283,5 @@ def read_image_file(path):
         num_rows = get_int(data[8:12])
         num_cols = get_int(data[12:16])
         images = []
-        idx = 16
-        for l in range(length):
-            img = []
-            images.append(img)
-            for r in range(num_rows):
-                row = []
-                img.append(row)
-                for c in range(num_cols):
-                    row.append(parse_byte(data[idx]))
-                    idx += 1
-        assert len(images) == length
-        return torch.ByteTensor(images).view(-1, 28, 28)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
+        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
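Usage note (commentary, not part of the patch): the new class is constructed like MNIST but takes an extra split argument, and an invalid split fails fast with the ValueError above. A minimal sketch, assuming the patch is applied; the root path and transform are illustrative choices, not from the diff:

```python
from torchvision import datasets, transforms

# Hypothetical usage; './data' and the transform pipeline are illustrative.
emnist = datasets.EMNIST(
    './data',                         # processed/training_letters.pt etc. land here
    split='letters',                  # must be one of EMNIST.splits, else ValueError
    train=True,
    download=True,                    # first call fetches gzip.zip and processes all splits
    transform=transforms.ToTensor(),
)
img, target = emnist[0]               # 1x28x28 float tensor and an integer class label
print(len(emnist), img.size(), target)
```

Because download() loops over every entry in splits, the first download is slow, but constructing the dataset with a different split afterwards just loads its already-processed .pt files.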
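The download step also has to unpack two nested archive formats: the NIST file is a zip whose members are themselves gzip-compressed IDX files. A condensed sketch of that unpack pattern, with a hypothetical helper name and shutil.copyfileobj streaming in place of the patch's whole-file read()/write():

```python
import gzip
import os
import shutil
import zipfile

def unpack_nested(zip_path, out_dir):
    # Level 1: the zip expands into a 'gzip/' folder of .gz members.
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)
    gz_dir = os.path.join(out_dir, 'gzip')
    # Level 2: decompress each member into out_dir, then drop the folder.
    for name in os.listdir(gz_dir):
        if name.endswith('.gz'):
            with gzip.open(os.path.join(gz_dir, name), 'rb') as src, \
                    open(os.path.join(out_dir, name[:-3]), 'wb') as dst:
                shutil.copyfileobj(src, dst)
    shutil.rmtree(gz_dir)
```

Streaming with copyfileobj keeps peak memory flat regardless of archive size, a small deviation from the patch, which reads each decompressed file into memory at once.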
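The other half of the patch replaces the per-byte Python loops in read_label_file/read_image_file with a single np.frombuffer view. A minimal sketch of why the two approaches agree, run on a tiny synthetic IDX-style buffer (the buffer contents and variable names are mine, for illustration; get_int is the helper from the patch):

```python
import codecs

import numpy as np
import torch

def get_int(b):
    # Same helper as in the patch: IDX header fields are big-endian 32-bit ints.
    return int(codecs.encode(b, 'hex'), 16)

# Synthetic IDX image file: magic 2051 (0x0803), then 2 images of 2x2 uint8 pixels.
header = b'\x00\x00\x08\x03' + b'\x00\x00\x00\x02' * 3
data = header + bytes(range(8))

assert get_int(data[:4]) == 2051
length, num_rows, num_cols = (get_int(data[i:i + 4]) for i in (4, 8, 12))

# Old approach: materialize one Python int per pixel, then build a tensor.
slow = torch.ByteTensor(list(data[16:])).view(length, num_rows, num_cols)

# New approach: reinterpret the raw buffer in one shot, no Python-level loop.
fast = torch.from_numpy(np.frombuffer(data, dtype=np.uint8, offset=16)) \
            .view(length, num_rows, num_cols)

assert torch.equal(slow, fast)
```

For the real MNIST training images the removed loop body ran roughly 47 million times (60000 x 28 x 28), which is where the parsing speedup comes from.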