From 754d526fe1cc6356eeefb1573b8e1a84f41361dd Mon Sep 17 00:00:00 2001
From: Soumith Chintala
Date: Thu, 10 Nov 2016 12:00:04 -0500
Subject: [PATCH] cifar 10 and 100

---
 .gitignore                       |   7 ++
 README.md                        |  10 ++
 test/cifar.py                    |  12 +++
 torchvision/datasets/__init__.py |   4 +-
 torchvision/datasets/cifar.py    | 167 +++++++++++++++++++++++++++++++
 5 files changed, 199 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 test/cifar.py
 create mode 100644 torchvision/datasets/cifar.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000000..11689d25a98
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+build/
+dist/
+torchvision.egg-info/
+*/**/__pycache__
+*/**/*.pyc
+*/**/*~
+*~
\ No newline at end of file
diff --git a/README.md b/README.md
index 6d22155bcc1..495310f4cbd 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ The following dataset loaders are available:
 - [LSUN Classification](#lsun)
 - [ImageFolder](#imagefolder)
 - [Imagenet-12](#imagenet-12)
+- [CIFAR10 and CIFAR100](#cifar)
 
 Datasets have the API:
 - `__getitem__`
@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
 
 - ['bedroom_train', 'church_train', ...] : a list of categories to load
 
+### CIFAR
+
+`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
+`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`
+
+- `root` : root directory of the dataset where the `cifar-10-batches-py` folder (`cifar-100-python` for CIFAR100) exists, or where it will be saved when `download` is set to `True`
+- `train` : `True` = training set, `False` = test set
+- `download` : `True` = downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
+
 ### ImageFolder
 
 A generic data loader where the images are arranged in this way:
diff --git a/test/cifar.py b/test/cifar.py
new file mode 100644
index 00000000000..daf542fa800
--- /dev/null
+++ b/test/cifar.py
@@ -0,0 +1,12 @@
+import torch
+import torchvision.datasets as dset
+
+print('\n\nCifar 10')
+a = dset.CIFAR10(root="abc/def/ghi", download=True)
+
+print(a[3])
+
+print('\n\nCifar 100')
+a = dset.CIFAR100(root="abc/def/ghi", download=True)
+
+print(a[3])
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 45fb8509f22..2eac78c79c0 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -1,7 +1,9 @@
 from .lsun import LSUN, LSUNClass
 from .folder import ImageFolder
 from .coco import CocoCaptions, CocoDetection
+from .cifar import CIFAR10, CIFAR100
 
 __all__ = ('LSUN', 'LSUNClass', 'ImageFolder',
-           'CocoCaptions', 'CocoDetection')
+           'CocoCaptions', 'CocoDetection',
+           'CIFAR10', 'CIFAR100')
 
diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py
new file mode 100644
index 00000000000..9768dbb03c8
--- /dev/null
+++ b/torchvision/datasets/cifar.py
@@ -0,0 +1,167 @@
+from __future__ import print_function
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+import errno
+import numpy as np
+import sys
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+
+
+class CIFAR10(data.Dataset):
+    base_folder = 'cifar-10-batches-py'
+    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
+    train_list = [
+        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
+        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
+        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
+        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
+        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
+    ]
+
+    test_list = [
+        ['test_batch', '40351d587109b95175f43aff81a1287e'],
+    ]
+
+    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        # now load the pickled numpy arrays for the requested split
+        if self.train:
+            self.train_data = []
+            self.train_labels = []
+            for fentry in self.train_list:
+                f = fentry[0]
+                file = os.path.join(root, self.base_folder, f)
+                fo = open(file, 'rb')
+                if sys.version_info[0] == 2:
+                    entry = pickle.load(fo)
+                else:
+                    # the batch files were pickled with Python 2
+                    entry = pickle.load(fo, encoding='latin1')
+                self.train_data.append(entry['data'])
+                if 'labels' in entry:
+                    self.train_labels += entry['labels']
+                else:
+                    # CIFAR-100 stores its targets under 'fine_labels'
+                    self.train_labels += entry['fine_labels']
+                fo.close()
+
+            self.train_data = np.concatenate(self.train_data)
+            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
+        else:
+            f = self.test_list[0][0]
+            file = os.path.join(root, self.base_folder, f)
+            fo = open(file, 'rb')
+            if sys.version_info[0] == 2:
+                entry = pickle.load(fo)
+            else:
+                entry = pickle.load(fo, encoding='latin1')
+            self.test_data = entry['data']
+            if 'labels' in entry:
+                self.test_labels = entry['labels']
+            else:
+                self.test_labels = entry['fine_labels']
+            fo.close()
+            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
+
+    def __getitem__(self, index):
+        if self.train:
+            img, target = self.train_data[index], self.train_labels[index]
+        else:
+            img, target = self.test_data[index], self.test_labels[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        if self.train:
+            return 50000
+        else:
+            return 10000
+
+    def _check_integrity(self):
+        import hashlib
+        root = self.root
+        for fentry in (self.train_list + self.test_list):
+            filename, md5 = fentry[0], fentry[1]
+            fpath = os.path.join(root, self.base_folder, filename)
+            if not os.path.isfile(fpath):
+                return False
+            md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
+            if md5c != md5:
+                return False
+        return True
+
+    def download(self):
+        from six.moves import urllib
+        import tarfile
+        import hashlib
+
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+
+        try:
+            os.makedirs(root)
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+
+        # download the archive unless a verified copy is already present
+        if os.path.isfile(fpath) and \
+           hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
+            print('Using downloaded file: ' + fpath)
+        else:
+            print('Downloading ' + self.url + ' to ' + fpath)
+            urllib.request.urlretrieve(self.url, fpath)
+
+        # extract file
+        cwd = os.getcwd()
+        print('Extracting tar file')
+        tar = tarfile.open(fpath, "r:gz")
+        os.chdir(root)
+        tar.extractall()
+        tar.close()
+        os.chdir(cwd)
+        print('Done!')
+
+
+class CIFAR100(CIFAR10):
+    base_folder = 'cifar-100-python'
+    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+    train_list = [
+        ['train', '16019d7e3df5f24257cddd939b257f8d'],
+    ]
+
+    test_list = [
+        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+    ]
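
Usage sketch: a minimal way to exercise the loaders this patch adds, assuming
the package is installed with the change applied; the './data' path is
illustrative. With transform=None, __getitem__ returns a (3, 32, 32) uint8
numpy array and an integer label, per the code above.

    import torchvision.datasets as dset

    # downloads the archive if needed, verifies the md5 sums, then loads
    # the 50000-image training split into memory
    train_set = dset.CIFAR10(root='./data', train=True, download=True)
    print(len(train_set))       # 50000

    img, target = train_set[0]
    print(img.shape, target)    # (3, 32, 32) array and an int class index

    # the held-out split works the same way via train=False
    test_set = dset.CIFAR100(root='./data', train=False, download=True)
    print(len(test_set))        # 10000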