Skip to content

Commit 7044049

Browse files
vishwakftwsoumith
authored andcommitted
Add description for Dataset objects (#384)
* add __repr__ for datasets * fix lint
1 parent a8071d5 commit 7044049

File tree

10 files changed

+103
-1
lines changed

10 files changed

+103
-1
lines changed

torchvision/datasets/cifar.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ def download(self):
159159
tar.close()
160160
os.chdir(cwd)
161161

162+
def __repr__(self):
163+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
164+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
165+
tmp = 'train' if self.train is True else 'test'
166+
fmt_str += ' Split: {}\n'.format(tmp)
167+
fmt_str += ' Root Location: {}\n'.format(self.root)
168+
tmp = ' Transforms (if any): '
169+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
170+
tmp = ' Target Transforms (if any): '
171+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172+
return fmt_str
173+
162174

163175
class CIFAR100(CIFAR10):
164176
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

torchvision/datasets/coco.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,13 @@ def __getitem__(self, index):
125125

126126
def __len__(self):
127127
return len(self.ids)
128+
129+
def __repr__(self):
130+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
131+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
132+
fmt_str += ' Root Location: {}\n'.format(self.root)
133+
tmp = ' Transforms (if any): '
134+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
135+
tmp = ' Target Transforms (if any): '
136+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
137+
return fmt_str

torchvision/datasets/fakedata.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,12 @@ def __getitem__(self, index):
5454

5555
def __len__(self):
5656
return self.size
57+
58+
def __repr__(self):
59+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
60+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
61+
tmp = ' Transforms (if any): '
62+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
63+
tmp = ' Target Transforms (if any): '
64+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
65+
return fmt_str

torchvision/datasets/folder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,13 @@ def __getitem__(self, index):
129129

130130
def __len__(self):
131131
return len(self.imgs)
132+
133+
def __repr__(self):
134+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
135+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
136+
fmt_str += ' Root Location: {}\n'.format(self.root)
137+
tmp = ' Transforms (if any): '
138+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
139+
tmp = ' Target Transforms (if any): '
140+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
141+
return fmt_str

torchvision/datasets/lsun.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,11 @@ def __len__(self):
143143
return self.length
144144

145145
def __repr__(self):
146-
return self.__class__.__name__ + ' (' + self.db_path + ')'
146+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
147+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
148+
fmt_str += ' Root Location: {}\n'.format(self.root)
149+
tmp = ' Transforms (if any): '
150+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
151+
tmp = ' Target Transforms (if any): '
152+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
153+
return fmt_str

torchvision/datasets/mnist.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def download(self):
139139

140140
print('Done!')
141141

142+
def __repr__(self):
143+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
144+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
145+
tmp = 'train' if self.train is True else 'test'
146+
fmt_str += ' Split: {}\n'.format(tmp)
147+
fmt_str += ' Root Location: {}\n'.format(self.root)
148+
tmp = ' Transforms (if any): '
149+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
150+
tmp = ' Target Transforms (if any): '
151+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
152+
return fmt_str
153+
142154

143155
class FashionMNIST(MNIST):
144156
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

torchvision/datasets/phototour.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ def download(self):
152152
with open(self.data_file, 'wb') as f:
153153
torch.save(dataset, f)
154154

155+
def __repr__(self):
156+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
157+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
158+
tmp = 'train' if self.train is True else 'test'
159+
fmt_str += ' Split: {}\n'.format(tmp)
160+
fmt_str += ' Root Location: {}\n'.format(self.root)
161+
tmp = ' Transforms (if any): '
162+
fmt_str += '{0}{1}'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
163+
return fmt_str
164+
155165

156166
def read_image_file(data_dir, image_ext, n):
157167
"""Return a Tensor containing the patches

torchvision/datasets/semeion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,13 @@ def download(self):
9191

9292
root = self.root
9393
download_url(self.url, root, self.filename, self.md5_checksum)
94+
95+
def __repr__(self):
96+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
97+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
98+
fmt_str += ' Root Location: {}\n'.format(self.root)
99+
tmp = ' Transforms (if any): '
100+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
101+
tmp = ' Target Transforms (if any): '
102+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
103+
return fmt_str

torchvision/datasets/stl10.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,14 @@ def __loadfile(self, data_file, labels_file=None):
126126
images = np.transpose(images, (0, 1, 3, 2))
127127

128128
return images, labels
129+
130+
def __repr__(self):
131+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
132+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
133+
fmt_str += ' Split: {}\n'.format(self.split)
134+
fmt_str += ' Root Location: {}\n'.format(self.root)
135+
tmp = ' Transforms (if any): '
136+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
137+
tmp = ' Target Transforms (if any): '
138+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
139+
return fmt_str

torchvision/datasets/svhn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,14 @@ def _check_integrity(self):
115115
def download(self):
116116
md5 = self.split_list[self.split][2]
117117
download_url(self.url, self.root, self.filename, md5)
118+
119+
def __repr__(self):
120+
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
121+
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
122+
fmt_str += ' Split: {}\n'.format(self.split)
123+
fmt_str += ' Root Location: {}\n'.format(self.root)
124+
tmp = ' Transforms (if any): '
125+
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
126+
tmp = ' Target Transforms (if any): '
127+
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
128+
return fmt_str

0 commit comments

Comments
 (0)