diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index f10c83d57db..8c477e64810 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -2,19 +2,12 @@ from PIL import Image import os import os.path -import numpy as np -import sys -if sys.version_info[0] == 2: - import cPickle as pickle -else: - import pickle -import collections -import torch.utils.data as data -from .utils import download_url, check_integrity, makedir_exist_ok +from .vision import VisionDataset +from .utils import download_url, makedir_exist_ok -class Caltech101(data.Dataset): +class Caltech101(VisionDataset): """`Caltech 101 `_ Dataset. Args: @@ -36,7 +29,7 @@ class Caltech101(data.Dataset): def __init__(self, root, target_type="category", transform=None, target_transform=None, download=False): - self.root = os.path.join(os.path.expanduser(root), "caltech101") + super(Caltech101, self).__init__(os.path.join(root, 'caltech101')) makedir_exist_ok(self.root) if isinstance(target_type, list): self.target_type = target_type @@ -138,19 +131,11 @@ def download(self): with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: tar.extractall(path=self.root) - def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Target type: {}\n'.format(self.target_type) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str + def extra_repr(self): + return "Target type: {target_type}".format(**self.__dict__) -class Caltech256(data.Dataset): +class Caltech256(VisionDataset): """`Caltech 256 `_ Dataset. Args: @@ -168,7 +153,7 @@ class Caltech256(data.Dataset): def __init__(self, root, transform=None, target_transform=None, download=False): - self.root = os.path.join(os.path.expanduser(root), "caltech256") + super(Caltech256, self).__init__(os.path.join(root, 'caltech256')) makedir_exist_ok(self.root) self.transform = transform self.target_transform = target_transform @@ -233,13 +218,3 @@ def download(self): # extract file with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: tar.extractall(path=self.root) - - def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index e38cd6bb6cd..1c466dc0777 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,11 +1,11 @@ import torch -import torch.utils.data as data import os import PIL +from .vision import VisionDataset from .utils import download_file_from_google_drive, check_integrity -class CelebA(data.Dataset): +class CelebA(VisionDataset): """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. Args: @@ -53,7 +53,7 @@ def __init__(self, root, transform=None, target_transform=None, download=False): import pandas - self.root = os.path.expanduser(root) + super(CelebA, self).__init__(root) self.split = split if isinstance(target_type, list): self.target_type = target_type @@ -158,14 +158,6 @@ def __getitem__(self, index): def __len__(self): return len(self.attr) - def __repr__(self): - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Target type: {}\n'.format(self.target_type) - fmt_str += ' Split: {}\n'.format(self.split) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - tmp = ' Target Transforms (if any): ' - fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str + def extra_repr(self): + lines = ["Target type: {target_type}", "Split: {split}"] + return '\n'.join(lines).format(**self.__dict__) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 8a1268944eb..ea48c2fab56 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -132,25 +132,8 @@ def valid_splits(self): def split_folder(self): return os.path.join(self.root, self.split) - def __repr__(self): - head = "Dataset " + self.__class__.__name__ - body = ["Number of datapoints: {}".format(self.__len__())] - if self.root is not None: - body.append("Root location: {}".format(self.root)) - body += ["Split: {}".format(self.split)] - if hasattr(self, 'transform') and self.transform is not None: - body += self._format_transform_repr(self.transform, - "Transforms: ") - if hasattr(self, 'target_transform') and self.target_transform is not None: - body += self._format_transform_repr(self.target_transform, - "Target transforms: ") - lines = [head] + [" " * 4 + line for line in body] - return '\n'.join(lines) - - def _format_transform_repr(self, transform, head): - lines = transform.__repr__().splitlines() - return (["{}{}".format(head, lines[0])] + - ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + def extra_repr(self): + return "Split: {split}".format(**self.__dict__) def extract_tar(src, dest=None, gzip=None, delete=False): diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index a4939622803..4072142658c 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -1,5 +1,5 @@ import os -import torch.utils.data as data +from .vision import VisionDataset import numpy as np @@ -8,7 +8,7 @@ from .voc import download_extract -class SBDataset(data.Dataset): +class SBDataset(VisionDataset): """`Semantic Boundaries Dataset `_ The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset. @@ -62,10 +62,11 @@ def __init__(self, raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy") + super(SBDataset, self).__init__(root) + if mode not in ("segmentation", "boundaries"): raise ValueError("Argument mode should be 'segmentation' or 'boundaries'") - self.root = os.path.expanduser(root) self.xy_transform = xy_transform self.image_set = image_set self.mode = mode @@ -121,3 +122,7 @@ def __getitem__(self, index): def __len__(self): return len(self.images) + + def extra_repr(self): + lines = ["Image set: {image_set}", "Mode: {mode}"] + return '\n'.join(lines).format(**self.__dict__)