Skip to content

Refactoring of the datasets #749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 4, 2019
Merged
22 changes: 8 additions & 14 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import os.path
import numpy as np
import sys

if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle

import torch.utils.data as data
from .vision import VisionDataset
from .utils import download_url, check_integrity


class CIFAR10(data.Dataset):
class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

Args:
Expand Down Expand Up @@ -54,9 +55,11 @@ class CIFAR10(data.Dataset):
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
self.root = os.path.expanduser(root)

super(CIFAR10, self).__init__(root)
self.transform = transform
self.target_transform = target_transform

self.train = train # training set or test set

if download:
Expand Down Expand Up @@ -153,17 +156,8 @@ def download(self):
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")


class CIFAR100(CIFAR10):
Expand Down
25 changes: 8 additions & 17 deletions torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import os
from collections import namedtuple

import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image


class Cityscapes(data.Dataset):
class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

Args:
Expand Down Expand Up @@ -93,12 +93,12 @@ class Cityscapes(data.Dataset):

def __init__(self, root, split='train', mode='fine', target_type='instance',
transform=None, target_transform=None):
self.root = os.path.expanduser(root)
super(Cityscapes, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, self.mode, split)
self.transform = transform
self.target_transform = target_transform
self.target_type = target_type
self.split = split
self.images = []
Expand Down Expand Up @@ -171,18 +171,9 @@ def __getitem__(self, index):
def __len__(self):
return len(self.images)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Mode: {}\n'.format(self.mode)
fmt_str += ' Type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
return '\n'.join(lines).format(**self.__dict__)

def _load_json(self, path):
with open(path, 'r') as file:
Expand Down
29 changes: 10 additions & 19 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
import os
import os.path


class CocoCaptions(data.Dataset):
class CocoCaptions(VisionDataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.

Args:
Expand Down Expand Up @@ -42,13 +42,14 @@ class CocoCaptions(data.Dataset):
u'A mountain view with a plume of smoke in the background']

"""

def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoCaptions, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO
self.root = os.path.expanduser(root)
self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
"""
Expand Down Expand Up @@ -79,7 +80,7 @@ def __len__(self):
return len(self.ids)


class CocoDetection(data.Dataset):
class CocoDetection(VisionDataset):
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

Args:
Expand All @@ -92,12 +93,12 @@ class CocoDetection(data.Dataset):
"""

def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO
self.root = root
self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
"""
Expand Down Expand Up @@ -125,13 +126,3 @@ def __getitem__(self, index):

def __len__(self):
return len(self.ids)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
16 changes: 5 additions & 11 deletions torchvision/datasets/fakedata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
import torch.utils.data as data
from .vision import VisionDataset
from .. import transforms


class FakeData(data.Dataset):
class FakeData(VisionDataset):
"""A fake dataset that returns randomly generated images and returns them as PIL images

Args:
Expand All @@ -21,6 +21,9 @@ class FakeData(data.Dataset):

def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
transform=None, target_transform=None, random_offset=0):
super(FakeData, self).__init__(None)
self.transform = transform
self.target_transform = target_transform
self.size = size
self.num_classes = num_classes
self.image_size = image_size
Expand Down Expand Up @@ -54,12 +57,3 @@ def __getitem__(self, index):

def __len__(self):
return self.size

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
16 changes: 9 additions & 7 deletions torchvision/datasets/flickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import glob
import os
import torch.utils.data as data
from .vision import VisionDataset


class Flickr8kParser(html_parser.HTMLParser):
Expand Down Expand Up @@ -50,7 +50,7 @@ def handle_data(self, data):
self.annotations[img_id].append(data.strip())


class Flickr8k(data.Dataset):
class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.

Args:
Expand All @@ -61,11 +61,12 @@ class Flickr8k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root)
self.ann_file = os.path.expanduser(ann_file)
super(Flickr8k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)

# Read annotations and store in a dict
parser = Flickr8kParser(self.root)
Expand Down Expand Up @@ -101,7 +102,7 @@ def __len__(self):
return len(self.ids)


class Flickr30k(data.Dataset):
class Flickr30k(VisionDataset):
"""`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.

Args:
Expand All @@ -112,11 +113,12 @@ class Flickr30k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""

def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root)
self.ann_file = os.path.expanduser(ann_file)
super(Flickr30k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)

# Read annotations and store in a dict
self.annotations = defaultdict(list)
Expand Down
26 changes: 8 additions & 18 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch.utils.data as data
from .vision import VisionDataset

from PIL import Image

Expand Down Expand Up @@ -51,7 +51,7 @@ def make_dataset(dir, class_to_idx, extensions):
return images


class DatasetFolder(data.Dataset):
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::

root/class_x/xxx.ext
Expand Down Expand Up @@ -80,13 +80,15 @@ class DatasetFolder(data.Dataset):
"""

def __init__(self, root, loader, extensions, transform=None, target_transform=None):
super(DatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
classes, class_to_idx = self._find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))

self.root = root
self.loader = loader
self.extensions = extensions

Expand All @@ -95,9 +97,6 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
self.samples = samples
self.targets = [s[1] for s in samples]

self.transform = transform
self.target_transform = target_transform

def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Expand Down Expand Up @@ -140,16 +139,6 @@ def __getitem__(self, index):
def __len__(self):
return len(self.samples)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

Expand Down Expand Up @@ -202,6 +191,7 @@ class ImageFolder(DatasetFolder):
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""

def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
Expand Down
Loading