
Commit bbaa1b0

Philip Meier authored and fmassa committed
added support for VisionDataset (#838)
1 parent 8759f30 commit bbaa1b0

4 files changed: +24 -69 lines changed
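
The VisionDataset base class itself lives in torchvision/datasets/vision.py and is not part of this diff. Judging from the __repr__ and _format_transform_repr code that the datasets below delete (most completely in imagenet.py), the shared base plausibly looks roughly like the following sketch; the exact names and formatting are assumptions, not the committed implementation:

# Rough sketch of what the shared VisionDataset base class plausibly provides.
# NOT part of this diff: inferred from the __repr__/_format_transform_repr code
# removed from imagenet.py below, so torchvision/datasets/vision.py may differ.
import os

import torch.utils.data as data


class VisionDataset(data.Dataset):
    def __init__(self, root):
        # Subclasses pass in their (dataset-specific) root directory.
        self.root = os.path.expanduser(root)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        # Dataset-specific lines are supplied by the subclass hook.
        body += self.extra_repr().splitlines()
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * 4 + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        # Override in subclasses to add dataset-specific lines; default is empty.
        return ""

With something like that in place, each dataset below only has to call super().__init__(root) and optionally override extra_repr(), instead of carrying its own copy of the formatting code.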

torchvision/datasets/caltech.py

Lines changed: 8 additions & 33 deletions
@@ -2,19 +2,12 @@
 from PIL import Image
 import os
 import os.path
-import numpy as np
-import sys
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
-import collections
 
-import torch.utils.data as data
-from .utils import download_url, check_integrity, makedir_exist_ok
+from .vision import VisionDataset
+from .utils import download_url, makedir_exist_ok
 
 
-class Caltech101(data.Dataset):
+class Caltech101(VisionDataset):
     """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
 
     Args:
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
     def __init__(self, root, target_type="category",
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech101")
+        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
         makedir_exist_ok(self.root)
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -138,19 +131,11 @@ def download(self):
         with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
             tar.extractall(path=self.root)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Target type: {target_type}".format(**self.__dict__)
 
 
-class Caltech256(data.Dataset):
+class Caltech256(VisionDataset):
     """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
 
     Args:
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
     def __init__(self, root,
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech256")
+        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
         makedir_exist_ok(self.root)
         self.transform = transform
         self.target_transform = target_transform
@@ -233,13 +218,3 @@ def download(self):
         # extract file
         with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
             tar.extractall(path=self.root)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
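
Assuming the base class composes extra_repr() into its __repr__ as sketched above, printing one of the refactored datasets would produce output along these lines; the exact wording and the datapoint count are illustrative, not taken from this diff:

# Hypothetical usage; the printed format depends on VisionDataset.__repr__,
# which is not part of this diff, so the commented output is illustrative only.
from torchvision.datasets import Caltech101

dataset = Caltech101(root="~/data", download=True)
print(dataset)
# Dataset Caltech101
#     Number of datapoints: <number of samples>
#     Root location: /.../data/caltech101
#     Target type: ['category']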

torchvision/datasets/celeba.py

Lines changed: 6 additions & 14 deletions
@@ -1,11 +1,11 @@
 import torch
-import torch.utils.data as data
 import os
 import PIL
+from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity
 
 
-class CelebA(data.Dataset):
+class CelebA(VisionDataset):
     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
 
     Args:
@@ -53,7 +53,7 @@ def __init__(self, root,
                  transform=None, target_transform=None,
                  download=False):
         import pandas
-        self.root = os.path.expanduser(root)
+        super(CelebA, self).__init__(root)
         self.split = split
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -158,14 +158,6 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.attr)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return '\n'.join(lines).format(**self.__dict__)

torchvision/datasets/imagenet.py

Lines changed: 2 additions & 19 deletions
@@ -132,25 +132,8 @@ def valid_splits(self):
     def split_folder(self):
         return os.path.join(self.root, self.split)
 
-    def __repr__(self):
-        head = "Dataset " + self.__class__.__name__
-        body = ["Number of datapoints: {}".format(self.__len__())]
-        if self.root is not None:
-            body.append("Root location: {}".format(self.root))
-        body += ["Split: {}".format(self.split)]
-        if hasattr(self, 'transform') and self.transform is not None:
-            body += self._format_transform_repr(self.transform,
-                                                "Transforms: ")
-        if hasattr(self, 'target_transform') and self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transforms: ")
-        lines = [head] + [" " * 4 + line for line in body]
-        return '\n'.join(lines)
-
-    def _format_transform_repr(self, transform, head):
-        lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+    def extra_repr(self):
+        return "Split: {split}".format(**self.__dict__)
 
 
 def extract_tar(src, dest=None, gzip=None, delete=False):

torchvision/datasets/sbd.py

Lines changed: 8 additions & 3 deletions
@@ -1,5 +1,5 @@
 import os
-import torch.utils.data as data
+from .vision import VisionDataset
 
 import numpy as np
 
@@ -8,7 +8,7 @@
 from .voc import download_extract
 
 
-class SBDataset(data.Dataset):
+class SBDataset(VisionDataset):
     """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
 
     The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
@@ -62,10 +62,11 @@ def __init__(self,
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                               "pip install scipy")
 
+        super(SBDataset, self).__init__(root)
+
         if mode not in ("segmentation", "boundaries"):
             raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
 
-        self.root = os.path.expanduser(root)
         self.xy_transform = xy_transform
         self.image_set = image_set
         self.mode = mode
@@ -121,3 +122,7 @@ def __getitem__(self, index):
 
     def __len__(self):
         return len(self.images)
+
+    def extra_repr(self):
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return '\n'.join(lines).format(**self.__dict__)
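
All four new extra_repr implementations rely on the same idiom: a template string whose placeholders are filled from the instance's attribute dictionary via str.format(**self.__dict__). A minimal standalone illustration of that idiom (the class here is made up for demonstration):

# Standalone illustration of the str.format(**self.__dict__) idiom used by the
# new extra_repr methods in this commit; this class is hypothetical.
class Example(object):
    def __init__(self):
        self.image_set = "train"
        self.mode = "boundaries"

    def extra_repr(self):
        lines = ["Image set: {image_set}", "Mode: {mode}"]
        return '\n'.join(lines).format(**self.__dict__)


print(Example().extra_repr())
# Image set: train
# Mode: boundaries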

0 commit comments