2
2
from PIL import Image
3
3
import os
4
4
import os .path
5
- import numpy as np
6
- import sys
7
- if sys .version_info [0 ] == 2 :
8
- import cPickle as pickle
9
- else :
10
- import pickle
11
- import collections
12
5
13
- import torch . utils . data as data
14
- from .utils import download_url , check_integrity , makedir_exist_ok
6
+ from . vision import VisionDataset
7
+ from .utils import download_url , makedir_exist_ok
15
8
16
9
17
- class Caltech101 (data . Dataset ):
10
+ class Caltech101 (VisionDataset ):
18
11
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
19
12
20
13
Args:
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
36
29
def __init__ (self , root , target_type = "category" ,
37
30
transform = None , target_transform = None ,
38
31
download = False ):
39
- self . root = os .path .join (os . path . expanduser ( root ), " caltech101" )
32
+ super ( Caltech101 , self ). __init__ ( os .path .join (root , ' caltech101' ) )
40
33
makedir_exist_ok (self .root )
41
34
if isinstance (target_type , list ):
42
35
self .target_type = target_type
@@ -138,19 +131,11 @@ def download(self):
138
131
with tarfile .open (os .path .join (self .root , "101_Annotations.tar" ), "r:" ) as tar :
139
132
tar .extractall (path = self .root )
140
133
141
- def __repr__ (self ):
142
- fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
143
- fmt_str += ' Number of datapoints: {}\n ' .format (self .__len__ ())
144
- fmt_str += ' Target type: {}\n ' .format (self .target_type )
145
- fmt_str += ' Root Location: {}\n ' .format (self .root )
146
- tmp = ' Transforms (if any): '
147
- fmt_str += '{0}{1}\n ' .format (tmp , self .transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
148
- tmp = ' Target Transforms (if any): '
149
- fmt_str += '{0}{1}' .format (tmp , self .target_transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
150
- return fmt_str
134
+ def extra_repr (self ):
135
+ return "Target type: {target_type}" .format (** self .__dict__ )
151
136
152
137
153
- class Caltech256 (data . Dataset ):
138
+ class Caltech256 (VisionDataset ):
154
139
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
155
140
156
141
Args:
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
168
153
def __init__ (self , root ,
169
154
transform = None , target_transform = None ,
170
155
download = False ):
171
- self . root = os .path .join (os . path . expanduser ( root ), " caltech256" )
156
+ super ( Caltech256 , self ). __init__ ( os .path .join (root , ' caltech256' ) )
172
157
makedir_exist_ok (self .root )
173
158
self .transform = transform
174
159
self .target_transform = target_transform
@@ -233,13 +218,3 @@ def download(self):
233
218
# extract file
234
219
with tarfile .open (os .path .join (self .root , "256_ObjectCategories.tar" ), "r:" ) as tar :
235
220
tar .extractall (path = self .root )
236
-
237
- def __repr__ (self ):
238
- fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
239
- fmt_str += ' Number of datapoints: {}\n ' .format (self .__len__ ())
240
- fmt_str += ' Root Location: {}\n ' .format (self .root )
241
- tmp = ' Transforms (if any): '
242
- fmt_str += '{0}{1}\n ' .format (tmp , self .transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
243
- tmp = ' Target Transforms (if any): '
244
- fmt_str += '{0}{1}' .format (tmp , self .target_transform .__repr__ ().replace ('\n ' , '\n ' + ' ' * len (tmp )))
245
- return fmt_str
0 commit comments