Skip to content

Added support for VisionDataset #838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 8 additions & 33 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,12 @@
from PIL import Image
import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import collections

import torch.utils.data as data
from .utils import download_url, check_integrity, makedir_exist_ok
from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok


class Caltech101(data.Dataset):
class Caltech101(VisionDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

Args:
Expand All @@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech101")
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
Expand Down Expand Up @@ -138,19 +131,11 @@ def download(self):
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
return "Target type: {target_type}".format(**self.__dict__)


class Caltech256(data.Dataset):
class Caltech256(VisionDataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

Args:
Expand All @@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
def __init__(self, root,
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech256")
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
Expand Down Expand Up @@ -233,13 +218,3 @@ def download(self):
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
20 changes: 6 additions & 14 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
import torch.utils.data as data
import os
import PIL
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity


class CelebA(data.Dataset):
class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

Args:
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, root,
transform=None, target_transform=None,
download=False):
import pandas
self.root = os.path.expanduser(root)
super(CelebA, self).__init__(root)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
Expand Down Expand Up @@ -158,14 +158,6 @@ def __getitem__(self, index):
def __len__(self):
return len(self.attr)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
lines = ["Target type: {target_type}", "Split: {split}"]
return '\n'.join(lines).format(**self.__dict__)
21 changes: 2 additions & 19 deletions torchvision/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,25 +132,8 @@ def valid_splits(self):
def split_folder(self):
return os.path.join(self.root, self.split)

def __repr__(self):
head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += ["Split: {}".format(self.split)]
if hasattr(self, 'transform') and self.transform is not None:
body += self._format_transform_repr(self.transform,
"Transforms: ")
if hasattr(self, 'target_transform') and self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transforms: ")
lines = [head] + [" " * 4 + line for line in body]
return '\n'.join(lines)

def _format_transform_repr(self, transform, head):
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
def extra_repr(self):
return "Split: {split}".format(**self.__dict__)


def extract_tar(src, dest=None, gzip=None, delete=False):
Expand Down
11 changes: 8 additions & 3 deletions torchvision/datasets/sbd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import torch.utils.data as data
from .vision import VisionDataset

import numpy as np

Expand All @@ -8,7 +8,7 @@
from .voc import download_extract


class SBDataset(data.Dataset):
class SBDataset(VisionDataset):
"""`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_

The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
Expand Down Expand Up @@ -62,10 +62,11 @@ def __init__(self,
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy")

super(SBDataset, self).__init__(root)

if mode not in ("segmentation", "boundaries"):
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")

self.root = os.path.expanduser(root)
self.xy_transform = xy_transform
self.image_set = image_set
self.mode = mode
Expand Down Expand Up @@ -121,3 +122,7 @@ def __getitem__(self, index):

def __len__(self):
return len(self.images)

def extra_repr(self):
lines = ["Image set: {image_set}", "Mode: {mode}"]
return '\n'.join(lines).format(**self.__dict__)