class Omniglot(data.Dataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Each sample is a ``(image, target)`` pair where ``target`` is the index of
    the character class, i.e. the position of the ``alphabet/character``
    folder in the listing built at construction time.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive of the selected set is missing or
            fails its integrity check (after the optional download).
    """
    folder = 'omniglot-py'
    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 digest of each zip archive, keyed by archive name without '.zip'.
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(self, root, background=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = join(os.path.expanduser(root), self.folder)
        self.background = background
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)
        # Relative 'alphabet/character' paths; the position in this list is
        # the class index. Built with a flat comprehension instead of
        # ``sum(lists, [])``, which re-copies the accumulator on every step.
        self._characters = [join(alphabet, character)
                            for alphabet in self._alphabets
                            for character in list_dir(join(self.target_folder, alphabet))]
        # One list of (image file name, class index) pairs per character.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
            for idx, character in enumerate(self._characters)
        ]
        # Same pairs, flattened in character order, for direct integer indexing.
        self._flat_character_images = [pair
                                       for images in self._character_images
                                       for pair in images]

    def __len__(self):
        """Return the total number of images over all character classes."""
        return len(self._flat_character_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        image = Image.open(image_path, mode='r').convert('L')

        # Compare against None rather than relying on truthiness, so that a
        # callable that happens to evaluate as falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self):
        """Return whether the selected set's zip archive passes ``check_integrity``.

        Only the zip file under ``self.root`` is checked, against its entry
        in ``zips_md5``; the extracted image folders are not inspected.
        """
        zip_filename = self._get_target_folder()
        return bool(check_integrity(join(self.root, zip_filename + '.zip'),
                                    self.zips_md5[zip_filename]))

    def download(self):
        """Download the selected set's zip archive and extract it into ``self.root``.

        Does nothing but print a message when the archive is already present
        and passes the integrity check.
        """
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_url(url, self.root, zip_filename, self.zips_md5[filename])
        print('Extracting downloaded file: ' + join(self.root, zip_filename))
        with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
            zip_file.extractall(self.root)

    def _get_target_folder(self):
        """Return the folder (and zip base) name of the selected set."""
        return 'images_background' if self.background else 'images_evaluation'
def list_dir(root, prefix=False):
    """List all directories at a given root, in sorted order.

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found

    Returns:
        list: Names (or paths, when ``prefix`` is true) of the directories
        directly under ``root``, sorted by name.
    """
    root = os.path.expanduser(root)
    # os.listdir returns entries in arbitrary order; sort so that callers
    # deriving indices from the listing get the same result everywhere.
    directories = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

    if prefix:
        directories = [os.path.join(root, d) for d in directories]

    return directories


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root, in sorted order.

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found

    Returns:
        list: Names (or paths, when ``prefix`` is true) of the matching
        regular files directly under ``root``, sorted by name.
    """
    root = os.path.expanduser(root)
    # Sorted for the same reason as in list_dir: os.listdir order is arbitrary.
    files = sorted(
        name for name in os.listdir(root)
        if name.endswith(suffix) and os.path.isfile(os.path.join(root, name))
    )

    if prefix:
        files = [os.path.join(root, f) for f in files]

    return files