Skip to content

Commit dac9efa

Browse files
activatedgeekfmassa
authored andcommitted
Omniglot Dataset (#323)
* Add basic Omniglot dataset loader * Remove unused import * Add Omniglot random pair to sample pair of characters * Precompute random set of pairs, deterministic after object instantiation * Export OmniglotRandomPair via the datasets module interfact * Fix naming convention, use sum instead of reduce * Fix downloading to not download everything, fix Python2 syntax * Fix end line lint * Add random_seed, syntax fixes * Remove randomized pair, take up as a separate generic wrapper * Fix master conflict
1 parent 7044049 commit dac9efa

File tree

5 files changed

+149
-3
lines changed

5 files changed

+149
-3
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ torchvision.egg-info/
55
*/**/*.pyc
66
*/**/*~
77
*~
8-
docs/build
8+
docs/build

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from .phototour import PhotoTour
99
from .fakedata import FakeData
1010
from .semeion import SEMEION
11+
from .omniglot import Omniglot
1112

1213
__all__ = ('LSUN', 'LSUNClass',
1314
'ImageFolder', 'FakeData',
1415
'CocoCaptions', 'CocoDetection',
1516
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
16-
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
17+
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
18+
'Omniglot')

torchvision/datasets/cifar.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from PIL import Image
33
import os
44
import os.path
5-
import errno
65
import numpy as np
76
import sys
87
if sys.version_info[0] == 2:

torchvision/datasets/omniglot.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import print_function
2+
from PIL import Image
3+
from os.path import join
4+
import os
5+
import torch.utils.data as data
6+
from .utils import download_url, check_integrity, list_dir, list_files
7+
8+
9+
class Omniglot(data.Dataset):
10+
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
11+
Args:
12+
root (string): Root directory of dataset where directory
13+
``omniglot-py`` exists.
14+
background (bool, optional): If True, creates dataset from the "background" set, otherwise
15+
creates from the "evaluation" set. This terminology is defined by the authors.
16+
transform (callable, optional): A function/transform that takes in an PIL image
17+
and returns a transformed version. E.g, ``transforms.RandomCrop``
18+
target_transform (callable, optional): A function/transform that takes in the
19+
target and transforms it.
20+
download (bool, optional): If true, downloads the dataset zip files from the internet and
21+
puts it in root directory. If the zip files are already downloaded, they are not
22+
downloaded again.
23+
"""
24+
folder = 'omniglot-py'
25+
download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
26+
zips_md5 = {
27+
'images_background': '68d2efa1b9178cc56df9314c21c6e718',
28+
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
29+
}
30+
31+
def __init__(self, root, background=True,
32+
transform=None, target_transform=None,
33+
download=False):
34+
self.root = join(os.path.expanduser(root), self.folder)
35+
self.background = background
36+
self.transform = transform
37+
self.target_transform = target_transform
38+
39+
if download:
40+
self.download()
41+
42+
if not self._check_integrity():
43+
raise RuntimeError('Dataset not found or corrupted.' +
44+
' You can use download=True to download it')
45+
46+
self.target_folder = join(self.root, self._get_target_folder())
47+
self._alphabets = list_dir(self.target_folder)
48+
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
49+
for a in self._alphabets], [])
50+
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
51+
for idx, character in enumerate(self._characters)]
52+
self._flat_character_images = sum(self._character_images, [])
53+
54+
def __len__(self):
55+
return len(self._flat_character_images)
56+
57+
def __getitem__(self, index):
58+
"""
59+
Args:
60+
index (int): Index
61+
62+
Returns:
63+
tuple: (image, target) where target is index of the target character class.
64+
"""
65+
image_name, character_class = self._flat_character_images[index]
66+
image_path = join(self.target_folder, self._characters[character_class], image_name)
67+
image = Image.open(image_path, mode='r').convert('L')
68+
69+
if self.transform:
70+
image = self.transform(image)
71+
72+
if self.target_transform:
73+
character_class = self.target_transform(character_class)
74+
75+
return image, character_class
76+
77+
def _check_integrity(self):
78+
zip_filename = self._get_target_folder()
79+
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
80+
return False
81+
return True
82+
83+
def download(self):
84+
import zipfile
85+
86+
if self._check_integrity():
87+
print('Files already downloaded and verified')
88+
return
89+
90+
filename = self._get_target_folder()
91+
zip_filename = filename + '.zip'
92+
url = self.download_url_prefix + '/' + zip_filename
93+
download_url(url, self.root, zip_filename, self.zips_md5[filename])
94+
print('Extracting downloaded file: ' + join(self.root, zip_filename))
95+
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
96+
zip_file.extractall(self.root)
97+
98+
def _get_target_folder(self):
99+
return 'images_background' if self.background else 'images_evaluation'

torchvision/datasets/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,49 @@ def download_url(url, root, filename, md5):
4545
print('Failed download. Trying https -> http instead.'
4646
' Downloading ' + url + ' to ' + fpath)
4747
urllib.request.urlretrieve(url, fpath)
48+
49+
50+
def list_dir(root, prefix=False):
51+
"""List all directories at a given root
52+
53+
Args:
54+
root (str): Path to directory whose folders need to be listed
55+
prefix (bool, optional): If true, prepends the path to each result, otherwise
56+
only returns the name of the directories found
57+
"""
58+
root = os.path.expanduser(root)
59+
directories = list(
60+
filter(
61+
lambda p: os.path.isdir(os.path.join(root, p)),
62+
os.listdir(root)
63+
)
64+
)
65+
66+
if prefix is True:
67+
directories = [os.path.join(root, d) for d in directories]
68+
69+
return directories
70+
71+
72+
def list_files(root, suffix, prefix=False):
73+
"""List all files ending with a suffix at a given root
74+
75+
Args:
76+
root (str): Path to directory whose folders need to be listed
77+
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
78+
It uses the Python "str.endswith" method and is passed directly
79+
prefix (bool, optional): If true, prepends the path to each result, otherwise
80+
only returns the name of the files found
81+
"""
82+
root = os.path.expanduser(root)
83+
files = list(
84+
filter(
85+
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
86+
os.listdir(root)
87+
)
88+
)
89+
90+
if prefix is True:
91+
files = [os.path.join(root, d) for d in files]
92+
93+
return files

0 commit comments

Comments
 (0)