diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py
index f10c83d57db..8c477e64810 100644
--- a/torchvision/datasets/caltech.py
+++ b/torchvision/datasets/caltech.py
@@ -2,19 +2,12 @@
from PIL import Image
import os
import os.path
-import numpy as np
-import sys
-if sys.version_info[0] == 2:
- import cPickle as pickle
-else:
- import pickle
-import collections
-import torch.utils.data as data
-from .utils import download_url, check_integrity, makedir_exist_ok
+from .vision import VisionDataset
+from .utils import download_url, makedir_exist_ok
-class Caltech101(data.Dataset):
+class Caltech101(VisionDataset):
"""`Caltech 101 `_ Dataset.
Args:
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
- self.root = os.path.join(os.path.expanduser(root), "caltech101")
+ super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
@@ -138,19 +131,11 @@ def download(self):
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
- def __repr__(self):
- fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
- fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
- fmt_str += ' Target type: {}\n'.format(self.target_type)
- fmt_str += ' Root Location: {}\n'.format(self.root)
- tmp = ' Transforms (if any): '
- fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- tmp = ' Target Transforms (if any): '
- fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- return fmt_str
+ def extra_repr(self):
+ return "Target type: {target_type}".format(**self.__dict__)
-class Caltech256(data.Dataset):
+class Caltech256(VisionDataset):
"""`Caltech 256 `_ Dataset.
Args:
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
def __init__(self, root,
transform=None, target_transform=None,
download=False):
- self.root = os.path.join(os.path.expanduser(root), "caltech256")
+ super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
@@ -233,13 +218,3 @@ def download(self):
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
-
- def __repr__(self):
- fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
- fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
- fmt_str += ' Root Location: {}\n'.format(self.root)
- tmp = ' Transforms (if any): '
- fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- tmp = ' Target Transforms (if any): '
- fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- return fmt_str
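
Note: the new VisionDataset base class (torchvision/datasets/vision.py) is not part of this diff. A minimal sketch of what the subclasses above assume, modeled on the __repr__/_format_transform_repr pair deleted from imagenet.py below; the names and exact formatting here are assumptions, not the shipped implementation:

    import os
    import torch.utils.data as data

    class VisionDataset(data.Dataset):
        _repr_indent = 4

        def __init__(self, root):
            # Expand "~" once here, so subclasses no longer call
            # os.path.expanduser themselves (see the removals above).
            self.root = os.path.expanduser(root)

        def __repr__(self):
            head = "Dataset " + self.__class__.__name__
            body = ["Number of datapoints: {}".format(self.__len__())]
            if self.root is not None:
                body.append("Root location: {}".format(self.root))
            # Subclasses contribute dataset-specific lines via extra_repr(),
            # mirroring the torch.nn.Module.extra_repr() convention.
            body += self.extra_repr().splitlines()
            if getattr(self, "transform", None) is not None:
                body += self._format_transform_repr(self.transform,
                                                    "Transforms: ")
            if getattr(self, "target_transform", None) is not None:
                body += self._format_transform_repr(self.target_transform,
                                                    "Target transforms: ")
            lines = [head] + [" " * self._repr_indent + line for line in body]
            return '\n'.join(lines)

        def extra_repr(self):
            # Default: nothing extra; "".splitlines() contributes no lines.
            return ""

        @staticmethod
        def _format_transform_repr(transform, head):
            lines = transform.__repr__().splitlines()
            return (["{}{}".format(head, lines[0])] +
                    ["{}{}".format(" " * len(head), line) for line in lines[1:]])
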
diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index e38cd6bb6cd..1c466dc0777 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -1,11 +1,11 @@
import torch
-import torch.utils.data as data
import os
import PIL
+from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity
-class CelebA(data.Dataset):
+class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset.
Args:
@@ -53,7 +53,7 @@ def __init__(self, root,
transform=None, target_transform=None,
download=False):
import pandas
- self.root = os.path.expanduser(root)
+ super(CelebA, self).__init__(root)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
@@ -158,14 +158,6 @@ def __getitem__(self, index):
def __len__(self):
return len(self.attr)
- def __repr__(self):
- fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
- fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
- fmt_str += ' Target type: {}\n'.format(self.target_type)
- fmt_str += ' Split: {}\n'.format(self.split)
- fmt_str += ' Root Location: {}\n'.format(self.root)
- tmp = ' Transforms (if any): '
- fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- tmp = ' Target Transforms (if any): '
- fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
- return fmt_str
+ def extra_repr(self):
+ lines = ["Target type: {target_type}", "Split: {split}"]
+ return '\n'.join(lines).format(**self.__dict__)
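
With a base __repr__ along the lines of the sketch above, printing a CelebA instance would render something like the following (datapoint count, root, and transform are illustrative, not real values):

    Dataset CelebA
        Number of datapoints: 162770
        Root location: ./celeba
        Target type: ['attr']
        Split: train
        Transforms: ToTensor()
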
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
index 8a1268944eb..ea48c2fab56 100644
--- a/torchvision/datasets/imagenet.py
+++ b/torchvision/datasets/imagenet.py
@@ -132,25 +132,8 @@ def valid_splits(self):
def split_folder(self):
return os.path.join(self.root, self.split)
- def __repr__(self):
- head = "Dataset " + self.__class__.__name__
- body = ["Number of datapoints: {}".format(self.__len__())]
- if self.root is not None:
- body.append("Root location: {}".format(self.root))
- body += ["Split: {}".format(self.split)]
- if hasattr(self, 'transform') and self.transform is not None:
- body += self._format_transform_repr(self.transform,
- "Transforms: ")
- if hasattr(self, 'target_transform') and self.target_transform is not None:
- body += self._format_transform_repr(self.target_transform,
- "Target transforms: ")
- lines = [head] + [" " * 4 + line for line in body]
- return '\n'.join(lines)
-
- def _format_transform_repr(self, transform, head):
- lines = transform.__repr__().splitlines()
- return (["{}{}".format(head, lines[0])] +
- ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+ def extra_repr(self):
+ return "Split: {split}".format(**self.__dict__)
def extract_tar(src, dest=None, gzip=None, delete=False):
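
The one-line extra_repr bodies rely on str.format(**self.__dict__), which resolves each format field against the instance's attributes. A self-contained illustration (the Demo class is hypothetical):

    class Demo(object):
        def __init__(self):
            self.split = "train"

        def extra_repr(self):
            # {split} is looked up in self.__dict__["split"]
            return "Split: {split}".format(**self.__dict__)

    print(Demo().extra_repr())  # -> Split: train

One caveat with this idiom: the named attribute must exist by the time __repr__ runs, otherwise format raises KeyError.
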
diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py
index a4939622803..4072142658c 100644
--- a/torchvision/datasets/sbd.py
+++ b/torchvision/datasets/sbd.py
@@ -1,5 +1,5 @@
import os
-import torch.utils.data as data
+from .vision import VisionDataset
import numpy as np
@@ -8,7 +8,7 @@
from .voc import download_extract
-class SBDataset(data.Dataset):
+class SBDataset(VisionDataset):
"""`Semantic Boundaries Dataset `_
The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
@@ -62,10 +62,11 @@ def __init__(self,
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy")
+ super(SBDataset, self).__init__(root)
+
if mode not in ("segmentation", "boundaries"):
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
- self.root = os.path.expanduser(root)
self.xy_transform = xy_transform
self.image_set = image_set
self.mode = mode
@@ -121,3 +122,7 @@ def __getitem__(self, index):
def __len__(self):
return len(self.images)
+
+ def extra_repr(self):
+ lines = ["Image set: {image_set}", "Mode: {mode}"]
+ return '\n'.join(lines).format(**self.__dict__)
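
Multi-line extra_repr strings work because the base class splits them with splitlines() and indents each resulting line. A hypothetical subclass of the VisionDataset sketch above, exercising the SBDataset pattern end to end:

    class TinySBD(VisionDataset):
        def __init__(self, root):
            super(TinySBD, self).__init__(root)
            self.image_set = "train"
            self.mode = "segmentation"

        def __len__(self):
            return 3

        def extra_repr(self):
            lines = ["Image set: {image_set}", "Mode: {mode}"]
            return '\n'.join(lines).format(**self.__dict__)

    print(TinySBD("./data"))
    # Dataset TinySBD
    #     Number of datapoints: 3
    #     Root location: ./data
    #     Image set: train
    #     Mode: segmentation
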