diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py
index 8441c9f7814..307e8f60a06 100644
--- a/torchvision/datasets/cifar.py
+++ b/torchvision/datasets/cifar.py
@@ -4,16 +4,17 @@
 import os.path
 import numpy as np
 import sys
+
 if sys.version_info[0] == 2:
     import cPickle as pickle
 else:
     import pickle
 
-import torch.utils.data as data
+from .vision import VisionDataset
 from .utils import download_url, check_integrity
 
 
-class CIFAR10(data.Dataset):
+class CIFAR10(VisionDataset):
     """`CIFAR10 `_ Dataset.
 
     Args:
@@ -54,9 +55,11 @@ class CIFAR10(data.Dataset):
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.expanduser(root)
+
+        super(CIFAR10, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
+
         self.train = train  # training set or test set
 
         if download:
@@ -153,17 +156,8 @@ def download(self):
         with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
             tar.extractall(path=self.root)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        tmp = 'train' if self.train is True else 'test'
-        fmt_str += '    Split: {}\n'.format(tmp)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Split: {}".format("Train" if self.train is True else "Test")
 
 
 class CIFAR100(CIFAR10):
diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py
index c8466b77b27..4f84839063b 100644
--- a/torchvision/datasets/cityscapes.py
+++ b/torchvision/datasets/cityscapes.py
@@ -2,11 +2,11 @@
 import os
 from collections import namedtuple
 
-import torch.utils.data as data
+from .vision import VisionDataset
 from PIL import Image
 
 
-class Cityscapes(data.Dataset):
+class Cityscapes(VisionDataset):
     """`Cityscapes `_ Dataset.
 
     Args:
@@ -93,12 +93,12 @@ class Cityscapes(data.Dataset):
 
     def __init__(self, root, split='train', mode='fine', target_type='instance',
                  transform=None, target_transform=None):
-        self.root = os.path.expanduser(root)
+        super(Cityscapes, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
         self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
         self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
         self.targets_dir = os.path.join(self.root, self.mode, split)
-        self.transform = transform
-        self.target_transform = target_transform
         self.target_type = target_type
         self.split = split
         self.images = []
@@ -171,18 +171,9 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.images)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Mode: {}\n'.format(self.mode)
-        fmt_str += '    Type: {}\n'.format(self.target_type)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+        return '\n'.join(lines).format(**self.__dict__)
 
     def _load_json(self, path):
         with open(path, 'r') as file:
diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py
index c2e4e30c7d3..219bdfdb928 100644
--- a/torchvision/datasets/coco.py
+++ b/torchvision/datasets/coco.py
@@ -1,10 +1,10 @@
-import torch.utils.data as data
+from .vision import VisionDataset
 from PIL import Image
 import os
 import os.path
 
 
-class CocoCaptions(data.Dataset):
+class CocoCaptions(VisionDataset):
     """`MS Coco Captions `_ Dataset.
 
     Args:
@@ -42,13 +42,14 @@ class CocoCaptions(data.Dataset):
         u'A mountain view with a plume of smoke in the background']
 
     """
+
     def __init__(self, root, annFile, transform=None, target_transform=None):
+        super(CocoCaptions, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
         from pycocotools.coco import COCO
-        self.root = os.path.expanduser(root)
         self.coco = COCO(annFile)
         self.ids = list(self.coco.imgs.keys())
-        self.transform = transform
-        self.target_transform = target_transform
 
     def __getitem__(self, index):
         """
@@ -79,7 +80,7 @@ def __len__(self):
         return len(self.ids)
 
 
-class CocoDetection(data.Dataset):
+class CocoDetection(VisionDataset):
     """`MS Coco Detection `_ Dataset.
 
     Args:
@@ -92,12 +93,12 @@ class CocoDetection(data.Dataset):
     """
 
     def __init__(self, root, annFile, transform=None, target_transform=None):
+        super(CocoDetection, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
         from pycocotools.coco import COCO
-        self.root = root
         self.coco = COCO(annFile)
         self.ids = list(self.coco.imgs.keys())
-        self.transform = transform
-        self.target_transform = target_transform
 
     def __getitem__(self, index):
         """
@@ -125,13 +126,3 @@ def __getitem__(self, index):
 
     def __len__(self):
         return len(self.ids)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py
index 9d3566f9a9f..dc390be6d31 100644
--- a/torchvision/datasets/fakedata.py
+++ b/torchvision/datasets/fakedata.py
@@ -1,9 +1,9 @@
 import torch
-import torch.utils.data as data
+from .vision import VisionDataset
 from .. import transforms
 
 
-class FakeData(data.Dataset):
+class FakeData(VisionDataset):
     """A fake dataset that returns randomly generated images and returns them as PIL images
 
     Args:
@@ -21,6 +21,9 @@ class FakeData(data.Dataset):
 
     def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
                  transform=None, target_transform=None, random_offset=0):
+        super(FakeData, self).__init__(None)
+        self.transform = transform
+        self.target_transform = target_transform
         self.size = size
         self.num_classes = num_classes
         self.image_size = image_size
@@ -54,12 +57,3 @@ def __getitem__(self, index):
 
     def __len__(self):
         return self.size
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
diff --git a/torchvision/datasets/flickr.py b/torchvision/datasets/flickr.py
index 70b0bb1e61a..9ab5e93140a 100644
--- a/torchvision/datasets/flickr.py
+++ b/torchvision/datasets/flickr.py
@@ -4,7 +4,7 @@
 import glob
 import os
 
-import torch.utils.data as data
+from .vision import VisionDataset
 
 
 class Flickr8kParser(html_parser.HTMLParser):
@@ -50,7 +50,7 @@ def handle_data(self, data):
             self.annotations[img_id].append(data.strip())
 
 
-class Flickr8k(data.Dataset):
+class Flickr8k(VisionDataset):
     """`Flickr8k Entities `_ Dataset.
 
     Args:
@@ -61,11 +61,12 @@ class Flickr8k(data.Dataset):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
""" + def __init__(self, root, ann_file, transform=None, target_transform=None): - self.root = os.path.expanduser(root) - self.ann_file = os.path.expanduser(ann_file) + super(Flickr8k, self).__init__(root) self.transform = transform self.target_transform = target_transform + self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict parser = Flickr8kParser(self.root) @@ -101,7 +102,7 @@ def __len__(self): return len(self.ids) -class Flickr30k(data.Dataset): +class Flickr30k(VisionDataset): """`Flickr30k Entities `_ Dataset. Args: @@ -112,11 +113,12 @@ class Flickr30k(data.Dataset): target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ + def __init__(self, root, ann_file, transform=None, target_transform=None): - self.root = os.path.expanduser(root) - self.ann_file = os.path.expanduser(ann_file) + super(Flickr30k, self).__init__(root) self.transform = transform self.target_transform = target_transform + self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict self.annotations = defaultdict(list) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 5b0411adbc5..eaad680f7ed 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,4 +1,4 @@ -import torch.utils.data as data +from .vision import VisionDataset from PIL import Image @@ -51,7 +51,7 @@ def make_dataset(dir, class_to_idx, extensions): return images -class DatasetFolder(data.Dataset): +class DatasetFolder(VisionDataset): """A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext @@ -80,13 +80,15 @@ class DatasetFolder(data.Dataset): """ def __init__(self, root, loader, extensions, transform=None, target_transform=None): + super(DatasetFolder, self).__init__(root) + self.transform = transform + self.target_transform = target_transform classes, class_to_idx = self._find_classes(root) samples = make_dataset(root, class_to_idx, extensions) if len(samples) == 0: - raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" - "Supported extensions are: " + ",".join(extensions))) + raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + "Supported extensions are: " + ",".join(extensions))) - self.root = root self.loader = loader self.extensions = extensions @@ -95,9 +97,6 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No self.samples = samples self.targets = [s[1] for s in samples] - self.transform = transform - self.target_transform = target_transform - def _find_classes(self, dir): """ Finds the class folders in a dataset. @@ -140,16 +139,6 @@ def __getitem__(self, index): def __len__(self): return len(self.samples) - def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str - IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp'] @@ -202,6 +191,7 @@ class ImageFolder(DatasetFolder): class_to_idx (dict): Dict with items (class_name, class_index). 
         imgs (list): List of (image path, class_index) tuples
     """
+
     def __init__(self, root, transform=None, target_transform=None,
                  loader=default_loader):
         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py
index 0d63a061a1f..6fba3fd49da 100644
--- a/torchvision/datasets/lsun.py
+++ b/torchvision/datasets/lsun.py
@@ -1,20 +1,21 @@
-import torch.utils.data as data
+from .vision import VisionDataset
 from PIL import Image
 import os
 import os.path
 import six
 import string
 import sys
+
 if sys.version_info[0] == 2:
     import cPickle as pickle
 else:
     import pickle
 
 
-class LSUNClass(data.Dataset):
+class LSUNClass(VisionDataset):
     def __init__(self, root, transform=None, target_transform=None):
         import lmdb
-        self.root = os.path.expanduser(root)
+        super(LSUNClass, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
 
@@ -52,11 +53,8 @@ def __getitem__(self, index):
     def __len__(self):
         return self.length
 
-    def __repr__(self):
-        return self.__class__.__name__ + ' (' + self.root + ')'
-
 
-class LSUN(data.Dataset):
+class LSUN(VisionDataset):
     """
     `LSUN `_ dataset.
 
@@ -72,13 +70,13 @@ class LSUN(data.Dataset):
 
     def __init__(self, root, classes='train',
                  transform=None, target_transform=None):
+        super(LSUN, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                       'conference_room', 'dining_room', 'kitchen',
                       'living_room', 'restaurant', 'tower']
         dset_opts = ['train', 'val', 'test']
-        self.root = os.path.expanduser(root)
-        self.transform = transform
-        self.target_transform = target_transform
 
         if type(classes) == str and classes in dset_opts:
             if classes == 'test':
@@ -91,15 +89,15 @@ def __init__(self, root, classes='train',
                c_short.pop(len(c_short) - 1)
                c_short = '_'.join(c_short)
                if c_short not in categories:
-                    raise(ValueError('Unknown LSUN class: ' + c_short + '.'
-                                     'Options are: ' + str(categories)))
+                    raise (ValueError('Unknown LSUN class: ' + c_short + '.'
+                                      'Options are: ' + str(categories)))
                c_short = c.split('_')
                c_short = c_short.pop(len(c_short) - 1)
                if c_short not in dset_opts:
-                    raise(ValueError('Unknown postfix: ' + c_short + '.'
-                                     'Options are: ' + str(dset_opts)))
+                    raise (ValueError('Unknown postfix: ' + c_short + '.'
+                                      'Options are: ' + str(dset_opts)))
         else:
-            raise(ValueError('Unknown option for classes'))
+            raise (ValueError('Unknown option for classes'))
         self.classes = classes
 
         # for each class, create an LSUNClassDataset
@@ -145,13 +143,5 @@ def __getitem__(self, index):
 
     def __len__(self):
         return self.length
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        fmt_str += '    Classes: {}\n'.format(self.classes)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Classes: {classes}".format(**self.__dict__)
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 1ae37db6f6d..50c99eeea68 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -1,6 +1,6 @@
 from __future__ import print_function
+from .vision import VisionDataset
 import warnings
-import torch.utils.data as data
 from PIL import Image
 import os
 import os.path
@@ -11,7 +11,7 @@
 from .utils import download_url, makedir_exist_ok
 
 
-class MNIST(data.Dataset):
+class MNIST(VisionDataset):
     """`MNIST `_ Dataset.
 
     Args:
@@ -59,7 +59,7 @@ def test_data(self):
         return self.data
 
     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
-        self.root = os.path.expanduser(root)
+        super(MNIST, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
         self.train = train  # training set or test set
@@ -115,8 +115,10 @@ def class_to_idx(self):
         return {_class: i for i, _class in enumerate(self.classes)}
 
     def _check_exists(self):
-        return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
-            os.path.exists(os.path.join(self.processed_folder, self.test_file))
+        return (os.path.exists(os.path.join(self.processed_folder,
+                                            self.training_file)) and
+                os.path.exists(os.path.join(self.processed_folder,
+                                            self.test_file)))
 
     @staticmethod
     def extract_gzip(gzip_path, remove_finished=False):
@@ -161,17 +163,8 @@ def download(self):
 
         print('Done!')
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        tmp = 'train' if self.train is True else 'test'
-        fmt_str += '    Split: {}\n'.format(tmp)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Split: {}".format("Train" if self.train is True else "Test")
 
 
 class FashionMNIST(MNIST):
diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py
index 6fff770b165..98a76f18c49 100644
--- a/torchvision/datasets/omniglot.py
+++ b/torchvision/datasets/omniglot.py
@@ -2,11 +2,11 @@
 from PIL import Image
 from os.path import join
 import os
-import torch.utils.data as data
+from .vision import VisionDataset
 from .utils import download_url, check_integrity, list_dir, list_files
 
 
-class Omniglot(data.Dataset):
+class Omniglot(VisionDataset):
     """`Omniglot `_
     Dataset.
     Args:
         root (string): Root directory of dataset where directory
@@ -31,10 +31,10 @@ class Omniglot(data.Dataset):
 
     def __init__(self, root, background=True,
                  transform=None, target_transform=None,
                  download=False):
-        self.root = join(os.path.expanduser(root), self.folder)
-        self.background = background
+        super(Omniglot, self).__init__(join(os.path.expanduser(root), self.folder))
         self.transform = transform
         self.target_transform = target_transform
+        self.background = background
 
         if download:
             self.download()
diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py
index 5752278ec13..1d02cbeaf0c 100644
--- a/torchvision/datasets/phototour.py
+++ b/torchvision/datasets/phototour.py
@@ -3,12 +3,12 @@
 from PIL import Image
 
 import torch
-import torch.utils.data as data
+from .vision import VisionDataset
 
 from .utils import download_url
 
 
-class PhotoTour(data.Dataset):
+class PhotoTour(VisionDataset):
     """`Learning Local Image Descriptors Data `_ Dataset.
 
 
@@ -65,14 +65,14 @@ class PhotoTour(data.Dataset):
     matches_files = 'm50_100000_100000_0.txt'
 
     def __init__(self, root, name, train=True, transform=None, download=False):
-        self.root = os.path.expanduser(root)
+        super(PhotoTour, self).__init__(root)
+        self.transform = transform
         self.name = name
         self.data_dir = os.path.join(self.root, name)
         self.data_down = os.path.join(self.root, '{}.zip'.format(name))
         self.data_file = os.path.join(self.root, '{}.pt'.format(name))
 
         self.train = train
-        self.transform = transform
 
         self.mean = self.mean[name]
         self.std = self.std[name]
@@ -151,20 +151,14 @@ def download(self):
             with open(self.data_file, 'wb') as f:
                 torch.save(dataset, f)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        tmp = 'train' if self.train is True else 'test'
-        fmt_str += '    Split: {}\n'.format(tmp)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Split: {}".format("Train" if self.train is True else "Test")
 
 
 def read_image_file(data_dir, image_ext, n):
     """Return a Tensor containing the patches
     """
+
     def PIL2array(_img):
         """Convert PIL image type to numpy 2D array
         """
@@ -211,5 +205,6 @@ def read_matches_files(data_dir, matches_file):
     with open(os.path.join(data_dir, matches_file), 'r') as f:
         for line in f:
             line_split = line.split()
-            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
+            matches.append([int(line_split[0]), int(line_split[3]),
+                            int(line_split[1] == line_split[4])])
     return torch.LongTensor(matches)
diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py
index e1dde981830..01c06940789 100644
--- a/torchvision/datasets/sbu.py
+++ b/torchvision/datasets/sbu.py
@@ -3,10 +3,10 @@
 from .utils import download_url, check_integrity
 
 import os
-import torch.utils.data as data
+from .vision import VisionDataset
 
 
-class SBU(data.Dataset):
+class SBU(VisionDataset):
     """`SBU Captioned Photo `_ Dataset.
 
     Args:
@@ -24,8 +24,9 @@ class SBU(data.Dataset):
     filename = "SBUCaptionedPhotoDataset.tar.gz"
     md5_checksum = '9aec147b3488753cf758b4d493422285'
 
-    def __init__(self, root, transform=None, target_transform=None, download=True):
-        self.root = os.path.expanduser(root)
+    def __init__(self, root, transform=None, target_transform=None,
+                 download=True):
+        super(SBU, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
 
diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py
index ce1136de9f8..207533964cb 100644
--- a/torchvision/datasets/semeion.py
+++ b/torchvision/datasets/semeion.py
@@ -3,11 +3,11 @@
 import os
 import os.path
 import numpy as np
-import torch.utils.data as data
+from .vision import VisionDataset
 from .utils import download_url, check_integrity
 
 
-class SEMEION(data.Dataset):
+class SEMEION(VisionDataset):
     """`SEMEION `_ Dataset.
     Args:
         root (string): Root directory of dataset where directory
@@ -24,8 +24,9 @@ class SEMEION(data.Dataset):
     filename = "semeion.data"
     md5_checksum = 'cb545d371d2ce14ec121470795a77432'
 
-    def __init__(self, root, transform=None, target_transform=None, download=True):
-        self.root = os.path.expanduser(root)
+    def __init__(self, root, transform=None, target_transform=None,
+                 download=True):
+        super(SEMEION, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
 
@@ -84,13 +85,3 @@ def download(self):
 
         root = self.root
         download_url(self.url, root, self.filename, self.md5_checksum)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py
index 0590ab8b21f..b5b303407a2 100644
--- a/torchvision/datasets/stl10.py
+++ b/torchvision/datasets/stl10.py
@@ -129,13 +129,5 @@ def __loadfile(self, data_file, labels_file=None):
 
         return images, labels
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Split: {split}".format(**self.__dict__)
diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py
index 6c1fb0a3adf..972125d74fb 100644
--- a/torchvision/datasets/svhn.py
+++ b/torchvision/datasets/svhn.py
@@ -1,5 +1,5 @@
 from __future__ import print_function
-import torch.utils.data as data
+from .vision import VisionDataset
 from PIL import Image
 import os
 import os.path
@@ -7,7 +7,7 @@
 from .utils import download_url, check_integrity
 
 
-class SVHN(data.Dataset):
+class SVHN(VisionDataset):
     """`SVHN `_ Dataset.
     Note: The SVHN dataset assigns the label `10` to the digit `0`.
     However, in this Dataset, we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
@@ -41,7 +41,7 @@ class SVHN(data.Dataset):
 
     def __init__(self, root, split='train',
                  transform=None, target_transform=None, download=False):
-        self.root = os.path.expanduser(root)
+        super(SVHN, self).__init__(root)
         self.transform = transform
         self.target_transform = target_transform
         self.split = split  # training set or test set or extra set
@@ -116,13 +116,5 @@ def download(self):
         md5 = self.split_list[self.split][2]
         download_url(self.url, self.root, self.filename, md5)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Split: {split}".format(**self.__dict__)
diff --git a/torchvision/datasets/vision.py b/torchvision/datasets/vision.py
new file mode 100644
index 00000000000..168388aadde
--- /dev/null
+++ b/torchvision/datasets/vision.py
@@ -0,0 +1,41 @@
+import os
+import torch
+import torch.utils.data as data
+
+
+class VisionDataset(data.Dataset):
+    _repr_indent = 4
+
+    def __init__(self, root):
+        if isinstance(root, torch._six.string_classes):
+            root = os.path.expanduser(root)
+        self.root = root
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def __repr__(self):
+        head = "Dataset " + self.__class__.__name__
+        body = ["Number of datapoints: {}".format(self.__len__())]
+        if self.root is not None:
+            body.append("Root location: {}".format(self.root))
+        body += self.extra_repr().splitlines()
+        if hasattr(self, 'transform') and self.transform is not None:
+            body += self._format_transform_repr(self.transform,
+                                                "Transforms: ")
+        if hasattr(self, 'target_transform') and self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform,
+                                                "Target transforms: ")
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return '\n'.join(lines)
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+    def extra_repr(self):
+        return ""
diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py
index 1ec44b3c9ff..96b96b459d4 100644
--- a/torchvision/datasets/voc.py
+++ b/torchvision/datasets/voc.py
@@ -2,7 +2,8 @@
 import sys
 import tarfile
 import collections
-import torch.utils.data as data
+from .vision import VisionDataset
+
 if sys.version_info[0] == 2:
     import xml.etree.cElementTree as ET
 else:
     import xml.etree.ElementTree as ET
@@ -51,7 +52,7 @@
 }
 
 
-class VOCSegmentation(data.Dataset):
+class VOCSegmentation(VisionDataset):
     """`Pascal VOC `_ Segmentation Dataset.
 
     Args:
@@ -74,13 +75,13 @@ def __init__(self,
                  download=False,
                  transform=None,
                  target_transform=None):
-        self.root = os.path.expanduser(root)
+        super(VOCSegmentation, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.transform = transform
-        self.target_transform = target_transform
         self.image_set = image_set
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
         voc_root = os.path.join(self.root, base_dir)
@@ -133,7 +134,7 @@ def __len__(self):
         return len(self.images)
 
 
-class VOCDetection(data.Dataset):
+class VOCDetection(VisionDataset):
     """`Pascal VOC `_ Detection Dataset.
 
     Args:
@@ -157,13 +158,13 @@ def __init__(self,
                  download=False,
                  transform=None,
                  target_transform=None):
-        self.root = os.path.expanduser(root)
+        super(VOCDetection, self).__init__(root)
+        self.transform = transform
+        self.target_transform = target_transform
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.transform = transform
-        self.target_transform = target_transform
         self.image_set = image_set
 
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
@@ -228,8 +229,8 @@ def parse_voc_xml(self, node):
                     def_dic[ind].append(v)
             voc_dict = {
                 node.tag:
-                    {ind: v[0] if len(v) == 1 else v
-                     for ind, v in def_dic.items()}
+                {ind: v[0] if len(v) == 1 else v
+                 for ind, v in def_dic.items()}
             }
         if node.text:
            text = node.text.strip()
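
Note on the new API: each hunk above swaps a hand-rolled __repr__ for the shared
implementation in VisionDataset, so a dataset calls super(Class, self).__init__(root)
and overrides only extra_repr() for its dataset-specific lines. A minimal sketch of
the resulting contract, assuming the patch above is applied (MyDataset and its
fields are hypothetical, for illustration only):

    # Sketch only: assumes this patch is applied, so that
    # torchvision.datasets.vision.VisionDataset exists.
    from torchvision.datasets.vision import VisionDataset


    class MyDataset(VisionDataset):  # hypothetical subclass
        def __init__(self, root, split='train', transform=None, target_transform=None):
            super(MyDataset, self).__init__(root)  # expands '~' and stores self.root
            self.transform = transform             # found via hasattr() in __repr__
            self.target_transform = target_transform
            self.split = split
            self.samples = list(range(10))         # stand-in for real data

        def __getitem__(self, index):
            return self.samples[index], 0

        def __len__(self):
            return len(self.samples)

        def extra_repr(self):
            # Only the dataset-specific lines; VisionDataset.__repr__ adds the rest.
            return "Split: {split}".format(**self.__dict__)


    print(MyDataset('~/data'))
    # Dataset MyDataset
    #     Number of datapoints: 10
    #     Root location: /home/<user>/data
    #     Split: train

Keeping the transform/target_transform assignments in each subclass (rather than in
VisionDataset.__init__) matches the patch as written: the base class only discovers
them via hasattr() when building the repr.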